compiler: implement numpy.full (#424).

This commit is contained in:
whitequark 2016-07-07 06:33:30 +00:00
parent 7a671fb2fd
commit d90fd7dc00
5 changed files with 45 additions and 1 deletions

View File

@ -170,6 +170,9 @@ def fn_min():
def fn_max(): def fn_max():
return types.TBuiltinFunction("max") return types.TBuiltinFunction("max")
def fn_make_array():
return types.TBuiltinFunction("make_array")
def fn_print(): def fn_print():
return types.TBuiltinFunction("print") return types.TBuiltinFunction("print")

View File

@ -170,6 +170,10 @@ class ASTSynthesizer:
typ = builtins.fn_array() typ = builtins.fn_array()
return asttyped.NameConstantT(value=None, type=typ, return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.array")) loc=self._add("numpy.array"))
elif value is numpy.full:
typ = builtins.fn_make_array()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.full"))
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):
if isinstance(value, int): if isinstance(value, int):
typ = builtins.TInt() typ = builtins.TInt()

View File

@ -1681,6 +1681,21 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Select(cond, arg0, arg1)) return self.append(ir.Select(cond, arg0, arg1))
else: else:
assert False assert False
elif types.is_builtin(typ, "make_array"):
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = map(self.visit, node.args)
result = self.append(ir.Alloc([arg0], node.type))
def body_gen(index):
self.append(ir.SetElem(result, index, arg1))
return self.append(ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, arg0.type)))
self._make_loop(ir.Constant(0, self._size_type),
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)),
body_gen)
return result
else:
assert False
elif types.is_builtin(typ, "print"): elif types.is_builtin(typ, "print"):
self.polymorphic_print([self.visit(arg) for arg in node.args], self.polymorphic_print([self.visit(arg) for arg in node.args],
separator=" ", suffix="\n") separator=" ", suffix="\n")
@ -1725,7 +1740,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"builtin function '{name}' cannot be used in this context", "builtin function '{name}' cannot be used in this context",
{"name": typ.name}, {"name": typ.find().name},
node.loc) node.loc)
self.engine.process(diag) self.engine.process(diag)

View File

@ -845,6 +845,23 @@ class Inferencer(algorithm.Visitor):
pass pass
else: else:
diagnose(valid_forms()) diagnose(valid_forms())
elif types.is_builtin(typ, "make_array"):
valid_forms = lambda: [
valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)")
]
self._unify(node.type, builtins.TArray(),
node.loc, None)
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = node.args
self._unify(arg0.type, builtins.TInt32(),
arg0.loc, None)
self._unify(arg1.type, node.type.find()["elt"],
arg1.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "rtio_log"): elif types.is_builtin(typ, "rtio_log"):
valid_forms = lambda: [ valid_forms = lambda: [
valid_form("rtio_log(channel:str, args...) -> None"), valid_form("rtio_log(channel:str, args...) -> None"),

View File

@ -109,6 +109,10 @@ class _RPC(EnvExperiment):
def numpy_things(self): def numpy_things(self):
return (numpy.int32(10), numpy.int64(20), numpy.array([42,])) return (numpy.int32(10), numpy.int64(20), numpy.array([42,]))
@kernel
def numpy_full(self):
return numpy.full(10, 20)
@kernel @kernel
def builtin(self): def builtin(self):
sleep(1.0) sleep(1.0)
@ -126,6 +130,7 @@ class RPCTest(ExperimentCase):
self.assertEqual(exp.args1kwargs2(), 2) self.assertEqual(exp.args1kwargs2(), 2)
self.assertEqual(exp.numpy_things(), self.assertEqual(exp.numpy_things(),
(numpy.int32(10), numpy.int64(20), numpy.array([42,]))) (numpy.int32(10), numpy.int64(20), numpy.array([42,])))
self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all())
exp.builtin() exp.builtin()