diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index a8b87491a..3fe2972b2 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -145,6 +145,8 @@ class Inferencer(algorithm.Visitor): def visit_IfExpT(self, node): self.generic_visit(node) + self._unify(node.test.type, builtins.TBool(), + node.test.loc, None) self._unify(node.body.type, node.orelse.type, node.body.loc, node.orelse.loc) self._unify(node.type, node.body.type, @@ -788,6 +790,11 @@ class Inferencer(algorithm.Visitor): node.value = self._coerce_one(value_type, node.value, other_node=node.target) + def visit_If(self, node): + self.generic_visit(node) + self._unify(node.test.type, builtins.TBool(), + node.test.loc, None) + def visit_For(self, node): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) @@ -798,6 +805,8 @@ class Inferencer(algorithm.Visitor): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) self.in_loop = old_in_loop + self._unify(node.test.type, builtins.TBool(), + node.test.loc, None) def visit_Break(self, node): if not self.in_loop: diff --git a/lit-test/compiler/inferencer/error_unify.py b/lit-test/compiler/inferencer/error_unify.py index e81537d29..dd5c617a8 100644 --- a/lit-test/compiler/inferencer/error_unify.py +++ b/lit-test/compiler/inferencer/error_unify.py @@ -25,3 +25,12 @@ a = b # CHECK-L: ${LINE:+1}: error: type int(width='a) does not have an attribute 'x' (1).x + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with bool +1 if 1 else 1 + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with bool +if 1: pass + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with bool +while 1: pass diff --git a/lit-test/compiler/inferencer/gcd.py b/lit-test/compiler/inferencer/gcd.py index e2d4b4779..87bd42716 100644 --- a/lit-test/compiler/inferencer/gcd.py +++ b/lit-test/compiler/inferencer/gcd.py @@ -3,7 +3,7 @@ def _gcd(a, b): if a < 0: a = -a - while a: + while a > 0: c = a a = b % a b = c diff --git a/lit-test/compiler/inferencer/unify.py b/lit-test/compiler/inferencer/unify.py index 48681fdcb..59abd4262 100644 --- a/lit-test/compiler/inferencer/unify.py +++ b/lit-test/compiler/inferencer/unify.py @@ -33,8 +33,8 @@ j = [] j += [1.0] # CHECK-L: j:list(elt=float) -1 if a else 2 -# CHECK-L: 1:int(width='f) if a:int(width='a) else 2:int(width='f):int(width='f) +1 if c else 2 +# CHECK-L: 1:int(width='f) if c:bool else 2:int(width='f):int(width='f) True and False # CHECK-L: True:bool and False:bool:bool