forked from M-Labs/artiq
compiler: implement numpy.full (#424).
This commit is contained in:
parent
7a671fb2fd
commit
d90fd7dc00
@ -170,6 +170,9 @@ def fn_min():
|
||||
def fn_max():
|
||||
return types.TBuiltinFunction("max")
|
||||
|
||||
def fn_make_array():
|
||||
return types.TBuiltinFunction("make_array")
|
||||
|
||||
def fn_print():
|
||||
return types.TBuiltinFunction("print")
|
||||
|
||||
|
@ -170,6 +170,10 @@ class ASTSynthesizer:
|
||||
typ = builtins.fn_array()
|
||||
return asttyped.NameConstantT(value=None, type=typ,
|
||||
loc=self._add("numpy.array"))
|
||||
elif value is numpy.full:
|
||||
typ = builtins.fn_make_array()
|
||||
return asttyped.NameConstantT(value=None, type=typ,
|
||||
loc=self._add("numpy.full"))
|
||||
elif isinstance(value, (int, float)):
|
||||
if isinstance(value, int):
|
||||
typ = builtins.TInt()
|
||||
|
@ -1681,6 +1681,21 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Select(cond, arg0, arg1))
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "make_array"):
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
arg0, arg1 = map(self.visit, node.args)
|
||||
|
||||
result = self.append(ir.Alloc([arg0], node.type))
|
||||
def body_gen(index):
|
||||
self.append(ir.SetElem(result, index, arg1))
|
||||
return self.append(ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, arg0.type)))
|
||||
self._make_loop(ir.Constant(0, self._size_type),
|
||||
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
|
||||
body_gen)
|
||||
return result
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "print"):
|
||||
self.polymorphic_print([self.visit(arg) for arg in node.args],
|
||||
separator=" ", suffix="\n")
|
||||
@ -1725,7 +1740,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"builtin function '{name}' cannot be used in this context",
|
||||
{"name": typ.name},
|
||||
{"name": typ.find().name},
|
||||
node.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
|
@ -845,6 +845,23 @@ class Inferencer(algorithm.Visitor):
|
||||
pass
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "make_array"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TArray(),
|
||||
node.loc, None)
|
||||
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
arg0, arg1 = node.args
|
||||
|
||||
self._unify(arg0.type, builtins.TInt32(),
|
||||
arg0.loc, None)
|
||||
self._unify(arg1.type, node.type.find()["elt"],
|
||||
arg1.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "rtio_log"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("rtio_log(channel:str, args...) -> None"),
|
||||
|
@ -109,6 +109,10 @@ class _RPC(EnvExperiment):
|
||||
def numpy_things(self):
|
||||
return (numpy.int32(10), numpy.int64(20), numpy.array([42,]))
|
||||
|
||||
@kernel
|
||||
def numpy_full(self):
|
||||
return numpy.full(10, 20)
|
||||
|
||||
@kernel
|
||||
def builtin(self):
|
||||
sleep(1.0)
|
||||
@ -126,6 +130,7 @@ class RPCTest(ExperimentCase):
|
||||
self.assertEqual(exp.args1kwargs2(), 2)
|
||||
self.assertEqual(exp.numpy_things(),
|
||||
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
||||
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
|
||||
exp.builtin()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user