diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index a11c8954e..dcbb1aab3 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -48,6 +48,7 @@ class Module: self.globals = src.globals int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) + cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine) monomorphism_validator = validators.MonomorphismValidator(engine=self.engine) escape_validator = validators.EscapeValidator(engine=self.engine) @@ -63,6 +64,7 @@ class Module: interleaver = transforms.Interleaver(engine=self.engine) invariant_detection = analyses.InvariantDetection(engine=self.engine) + cast_monomorphizer.visit(src.typedtree) int_monomorphizer.visit(src.typedtree) inferencer.visit(src.typedtree) monomorphism_validator.visit(src.typedtree) diff --git a/artiq/compiler/testbench/inferencer.py b/artiq/compiler/testbench/inferencer.py index 4179cc777..3e36d0a38 100644 --- a/artiq/compiler/testbench/inferencer.py +++ b/artiq/compiler/testbench/inferencer.py @@ -1,7 +1,7 @@ import sys, fileinput, os from pythonparser import source, diagnostic, algorithm, parse_buffer from .. import prelude, types -from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer +from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer from ..transforms import IODelayEstimator class Printer(algorithm.Visitor): @@ -84,6 +84,7 @@ def main(): typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed) Inferencer(engine=engine).visit(typed) if monomorphize: + CastMonomorphizer(engine=engine).visit(typed) IntMonomorphizer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed) if iodelay: diff --git a/artiq/compiler/transforms/__init__.py b/artiq/compiler/transforms/__init__.py index 665fd3ea1..a4a56d82d 100644 --- a/artiq/compiler/transforms/__init__.py +++ b/artiq/compiler/transforms/__init__.py @@ -1,6 +1,7 @@ from .asttyped_rewriter import ASTTypedRewriter from .inferencer import Inferencer from .int_monomorphizer import IntMonomorphizer +from .cast_monomorphizer import CastMonomorphizer from .iodelay_estimator import IODelayEstimator from .artiq_ir_generator import ARTIQIRGenerator from .dead_code_eliminator import DeadCodeEliminator diff --git a/artiq/compiler/transforms/cast_monomorphizer.py b/artiq/compiler/transforms/cast_monomorphizer.py new file mode 100644 index 000000000..c12eb9663 --- /dev/null +++ b/artiq/compiler/transforms/cast_monomorphizer.py @@ -0,0 +1,24 @@ +""" +:class:`CastMonomorphizer` uses explicit casts to monomorphize +expressions of undetermined integer type to either 32 or 64 bits. +""" + +from pythonparser import algorithm, diagnostic +from .. import types, builtins + +class CastMonomorphizer(algorithm.Visitor): + def __init__(self, engine): + self.engine = engine + + def visit_CallT(self, node): + self.generic_visit(node) + + if (types.is_builtin(node.func.type, "int") or + types.is_builtin(node.func.type, "int32") or + types.is_builtin(node.func.type, "int64")): + typ = node.type.find() + if (not types.is_var(typ["width"]) and + builtins.is_int(node.args[0].type) and + types.is_var(node.args[0].type.find()["width"])): + node.args[0].type.unify(typ) + diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index b55dfa98d..0c72b5ac6 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -780,7 +780,6 @@ class Inferencer(algorithm.Visitor): elif types.is_builtin(typ, "round"): valid_forms = lambda: [ valid_form("round(x:float) -> numpy.int?"), - valid_form("round(x:float, width=?) -> numpy.int?") ] self._unify(node.type, builtins.TInt(), @@ -791,19 +790,6 @@ class Inferencer(algorithm.Visitor): self._unify(arg.type, builtins.TFloat(), arg.loc, None) - elif len(node.args) == 1 and len(node.keywords) == 1 and \ - builtins.is_numeric(node.args[0].type) and \ - node.keywords[0].arg == 'width': - width = node.keywords[0].value - if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)): - diag = diagnostic.Diagnostic("error", - "the width argument of round() must be an integer literal", {}, - node.keywords[0].loc) - self.engine.process(diag) - return - - self._unify(node.type, builtins.TInt(types.TValue(width.n)), - node.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"): diff --git a/artiq/test/lit/inferencer/builtin_calls.py b/artiq/test/lit/inferencer/builtin_calls.py index bd33dd140..be1b797d9 100644 --- a/artiq/test/lit/inferencer/builtin_calls.py +++ b/artiq/test/lit/inferencer/builtin_calls.py @@ -30,6 +30,3 @@ len([]) # CHECK-L: round:(1.0:float):numpy.int? round(1.0) - -# CHECK-L: round:(1.0:float, width=64:numpy.int?):numpy.int64 -round(1.0, width=64) diff --git a/artiq/test/lit/monomorphism/integers.py b/artiq/test/lit/monomorphism/integers.py index 5c0e8be7f..e6760fffb 100644 --- a/artiq/test/lit/monomorphism/integers.py +++ b/artiq/test/lit/monomorphism/integers.py @@ -6,6 +6,3 @@ x = 1 y = int(1) # CHECK-L: y: numpy.int32 - -z = round(1.0) -# CHECK-L: z: numpy.int32 diff --git a/artiq/test/lit/monomorphism/round.py b/artiq/test/lit/monomorphism/round.py new file mode 100644 index 000000000..74df18401 --- /dev/null +++ b/artiq/test/lit/monomorphism/round.py @@ -0,0 +1,11 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: round:(1.0:float):numpy.int32 +round(1.0) + +# CHECK-L: round:(2.0:float):numpy.int32 +int32(round(2.0)) + +# CHECK-L: round:(3.0:float):numpy.int64 +int64(round(3.0))