forked from M-Labs/artiq
compiler: implement numpy.full (#424).
This commit is contained in:
parent
7a671fb2fd
commit
d90fd7dc00
|
@ -170,6 +170,9 @@ def fn_min():
|
||||||
def fn_max():
|
def fn_max():
|
||||||
return types.TBuiltinFunction("max")
|
return types.TBuiltinFunction("max")
|
||||||
|
|
||||||
|
def fn_make_array():
|
||||||
|
return types.TBuiltinFunction("make_array")
|
||||||
|
|
||||||
def fn_print():
|
def fn_print():
|
||||||
return types.TBuiltinFunction("print")
|
return types.TBuiltinFunction("print")
|
||||||
|
|
||||||
|
|
|
@ -170,6 +170,10 @@ class ASTSynthesizer:
|
||||||
typ = builtins.fn_array()
|
typ = builtins.fn_array()
|
||||||
return asttyped.NameConstantT(value=None, type=typ,
|
return asttyped.NameConstantT(value=None, type=typ,
|
||||||
loc=self._add("numpy.array"))
|
loc=self._add("numpy.array"))
|
||||||
|
elif value is numpy.full:
|
||||||
|
typ = builtins.fn_make_array()
|
||||||
|
return asttyped.NameConstantT(value=None, type=typ,
|
||||||
|
loc=self._add("numpy.full"))
|
||||||
elif isinstance(value, (int, float)):
|
elif isinstance(value, (int, float)):
|
||||||
if isinstance(value, int):
|
if isinstance(value, int):
|
||||||
typ = builtins.TInt()
|
typ = builtins.TInt()
|
||||||
|
|
|
@ -1681,6 +1681,21 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
return self.append(ir.Select(cond, arg0, arg1))
|
return self.append(ir.Select(cond, arg0, arg1))
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
elif types.is_builtin(typ, "make_array"):
|
||||||
|
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||||
|
arg0, arg1 = map(self.visit, node.args)
|
||||||
|
|
||||||
|
result = self.append(ir.Alloc([arg0], node.type))
|
||||||
|
def body_gen(index):
|
||||||
|
self.append(ir.SetElem(result, index, arg1))
|
||||||
|
return self.append(ir.Arith(ast.Add(loc=None), index,
|
||||||
|
ir.Constant(1, arg0.type)))
|
||||||
|
self._make_loop(ir.Constant(0, self._size_type),
|
||||||
|
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
|
||||||
|
body_gen)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
elif types.is_builtin(typ, "print"):
|
elif types.is_builtin(typ, "print"):
|
||||||
self.polymorphic_print([self.visit(arg) for arg in node.args],
|
self.polymorphic_print([self.visit(arg) for arg in node.args],
|
||||||
separator=" ", suffix="\n")
|
separator=" ", suffix="\n")
|
||||||
|
@ -1725,7 +1740,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
else:
|
else:
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"builtin function '{name}' cannot be used in this context",
|
"builtin function '{name}' cannot be used in this context",
|
||||||
{"name": typ.name},
|
{"name": typ.find().name},
|
||||||
node.loc)
|
node.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
|
|
@ -845,6 +845,23 @@ class Inferencer(algorithm.Visitor):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
|
elif types.is_builtin(typ, "make_array"):
|
||||||
|
valid_forms = lambda: [
|
||||||
|
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
|
||||||
|
]
|
||||||
|
|
||||||
|
self._unify(node.type, builtins.TArray(),
|
||||||
|
node.loc, None)
|
||||||
|
|
||||||
|
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||||
|
arg0, arg1 = node.args
|
||||||
|
|
||||||
|
self._unify(arg0.type, builtins.TInt32(),
|
||||||
|
arg0.loc, None)
|
||||||
|
self._unify(arg1.type, node.type.find()["elt"],
|
||||||
|
arg1.loc, None)
|
||||||
|
else:
|
||||||
|
diagnose(valid_forms())
|
||||||
elif types.is_builtin(typ, "rtio_log"):
|
elif types.is_builtin(typ, "rtio_log"):
|
||||||
valid_forms = lambda: [
|
valid_forms = lambda: [
|
||||||
valid_form("rtio_log(channel:str, args...) -> None"),
|
valid_form("rtio_log(channel:str, args...) -> None"),
|
||||||
|
|
|
@ -109,6 +109,10 @@ class _RPC(EnvExperiment):
|
||||||
def numpy_things(self):
|
def numpy_things(self):
|
||||||
return (numpy.int32(10), numpy.int64(20), numpy.array([42,]))
|
return (numpy.int32(10), numpy.int64(20), numpy.array([42,]))
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def numpy_full(self):
|
||||||
|
return numpy.full(10, 20)
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def builtin(self):
|
def builtin(self):
|
||||||
sleep(1.0)
|
sleep(1.0)
|
||||||
|
@ -126,6 +130,7 @@ class RPCTest(ExperimentCase):
|
||||||
self.assertEqual(exp.args1kwargs2(), 2)
|
self.assertEqual(exp.args1kwargs2(), 2)
|
||||||
self.assertEqual(exp.numpy_things(),
|
self.assertEqual(exp.numpy_things(),
|
||||||
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
(numpy.int32(10), numpy.int64(20), numpy.array([42,])))
|
||||||
|
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
|
||||||
exp.builtin()
|
exp.builtin()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue