From d90fd7dc00db60b56273f896d3aea33339db3133 Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 7 Jul 2016 06:33:30 +0000 Subject: [PATCH] compiler: implement numpy.full (#424). --- artiq/compiler/builtins.py | 3 +++ artiq/compiler/embedding.py | 4 ++++ artiq/compiler/transforms/artiq_ir_generator.py | 17 ++++++++++++++++- artiq/compiler/transforms/inferencer.py | 17 +++++++++++++++++ artiq/test/coredevice/test_embedding.py | 5 +++++ 5 files changed, 45 insertions(+), 1 deletion(-) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index f060330f0..830769eec 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -170,6 +170,9 @@ def fn_min(): def fn_max(): return types.TBuiltinFunction("max") +def fn_make_array(): + return types.TBuiltinFunction("make_array") + def fn_print(): return types.TBuiltinFunction("print") diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 9295c05ad..7b4806223 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -170,6 +170,10 @@ class ASTSynthesizer: typ = builtins.fn_array() return asttyped.NameConstantT(value=None, type=typ, loc=self._add("numpy.array")) + elif value is numpy.full: + typ = builtins.fn_make_array() + return asttyped.NameConstantT(value=None, type=typ, + loc=self._add("numpy.full")) elif isinstance(value, (int, float)): if isinstance(value, int): typ = builtins.TInt() diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 549bbef21..ee0232fc6 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1681,6 +1681,21 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Select(cond, arg0, arg1)) else: assert False + elif types.is_builtin(typ, "make_array"): + if len(node.args) == 2 and len(node.keywords) == 0: + arg0, arg1 = map(self.visit, node.args) + + result = self.append(ir.Alloc([arg0], node.type)) + def body_gen(index): + self.append(ir.SetElem(result, index, arg1)) + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, arg0.type))) + self._make_loop(ir.Constant(0, self._size_type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, arg0)), + body_gen) + return result + else: + assert False elif types.is_builtin(typ, "print"): self.polymorphic_print([self.visit(arg) for arg in node.args], separator=" ", suffix="\n") @@ -1725,7 +1740,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: diag = diagnostic.Diagnostic("error", "builtin function '{name}' cannot be used in this context", - {"name": typ.name}, + {"name": typ.find().name}, node.loc) self.engine.process(diag) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index b93e1528b..ea488f64d 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -845,6 +845,23 @@ class Inferencer(algorithm.Visitor): pass else: diagnose(valid_forms()) + elif types.is_builtin(typ, "make_array"): + valid_forms = lambda: [ + valid_form("numpy.full(count:int32, value:'a) -> numpy.array(elt='a)") + ] + + self._unify(node.type, builtins.TArray(), + node.loc, None) + + if len(node.args) == 2 and len(node.keywords) == 0: + arg0, arg1 = node.args + + self._unify(arg0.type, builtins.TInt32(), + arg0.loc, None) + self._unify(arg1.type, node.type.find()["elt"], + arg1.loc, None) + else: + diagnose(valid_forms()) elif types.is_builtin(typ, "rtio_log"): valid_forms = lambda: [ valid_form("rtio_log(channel:str, args...) -> None"), diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 7a1037ef2..1ad29201f 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -109,6 +109,10 @@ class _RPC(EnvExperiment): def numpy_things(self): return (numpy.int32(10), numpy.int64(20), numpy.array([42,])) + @kernel + def numpy_full(self): + return numpy.full(10, 20) + @kernel def builtin(self): sleep(1.0) @@ -126,6 +130,7 @@ class RPCTest(ExperimentCase): self.assertEqual(exp.args1kwargs2(), 2) self.assertEqual(exp.numpy_things(), (numpy.int32(10), numpy.int64(20), numpy.array([42,]))) + self.assertTrue((exp.numpy_full() == numpy.full(10, 20)).all()) exp.builtin()