diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 2e0bd629b..2d96d1372 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -708,6 +708,7 @@ class Inferencer(algorithm.Visitor): elif types.is_builtin(typ, "round"): valid_forms = lambda: [ valid_form("round(x:float) -> int(width='a)"), + valid_form("round(x:float, width='b:) -> int(width='b)") ] self._unify(node.type, builtins.TInt(), @@ -718,6 +719,19 @@ class Inferencer(algorithm.Visitor): self._unify(arg.type, builtins.TFloat(), arg.loc, None) + elif len(node.args) == 1 and len(node.keywords) == 1 and \ + builtins.is_numeric(node.args[0].type) and \ + node.keywords[0].arg == 'width': + width = node.keywords[0].value + if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)): + diag = diagnostic.Diagnostic("error", + "the width argument of round() must be an integer literal", {}, + node.keywords[0].loc) + self.engine.process(diag) + return + + self._unify(node.type, builtins.TInt(types.TValue(width.n)), + node.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "print"): diff --git a/lit-test/test/inferencer/builtin_calls.py b/lit-test/test/inferencer/builtin_calls.py index 61a97fb4a..1b3157e2e 100644 --- a/lit-test/test/inferencer/builtin_calls.py +++ b/lit-test/test/inferencer/builtin_calls.py @@ -30,3 +30,6 @@ len([]) # CHECK-L: round:(1.0:float):int(width='h) round(1.0) + +# CHECK-L: round:(1.0:float, width=64:int(width='i)):int(width=64) +round(1.0, width=64)