diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index 078051883..60768a349 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -180,6 +180,18 @@ def TList(elt=None): return TMono("list", {"elt": elt}) +def is_var(typ): + return isinstance(typ, TVar) + +def is_mono(typ, name, **params): + return isinstance(typ, TMono) and \ + typ.name == name and typ.params == params + +def is_numeric(typ): + return isinstance(typ, TMono) and \ + typ.name in ('int', 'float') + + class TypePrinter(object): """ A class that prints types using Python-like syntax and gives diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 2fc33fc9d..943b5b6fe 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -284,6 +284,27 @@ class Inferencer(algorithm.Transformer): op_locs=node.op_locs, loc=node.loc) return self.visit(node) + def visit_UnaryOp(self, node): + node = self.generic_visit(node) + node = asttyped.UnaryOpT(type=types.TVar(), + op=node.op, operand=node.operand, + loc=node.loc) + return self.visit(node) + + def visit_BinOp(self, node): + node = self.generic_visit(node) + node = asttyped.BinOpT(type=types.TVar(), + left=node.left, op=node.op, right=node.right, + loc=node.loc) + return self.visit(node) + + def visit_Compare(self, node): + node = self.generic_visit(node) + node = asttyped.CompareT(type=types.TVar(), + left=node.left, ops=node.ops, comparators=node.comparators, + loc=node.loc) + return self.visit(node) + # Visitors that just unify types # def visit_ListT(self, node): @@ -310,6 +331,21 @@ class Inferencer(algorithm.Transformer): node.loc, value.loc, self._makenotes_elts(node.values, "an operand")) return node + def visit_UnaryOpT(self, node): + if isinstance(node.op, ast.Not): + node.type = types.TBool() + else: + operand_type = node.operand.type.find() + if types.is_numeric(operand_type): + node.type = operand_type + elif not types.is_var(operand_type): + diag = diagnostic.Diagnostic("error", + "expected operand to be of numeric type, not {type}", + {"type": types.TypePrinter().name(operand_type)}, + node.operand.loc) + self.engine.process(diag) + return node + def visit_Assign(self, node): node = self.generic_visit(node) if len(node.targets) > 1: @@ -375,7 +411,6 @@ class Inferencer(algorithm.Transformer): visit_SetComp = visit_unsupported visit_Str = visit_unsupported visit_Starred = visit_unsupported - visit_UnaryOp = visit_unsupported visit_Yield = visit_unsupported visit_YieldFrom = visit_unsupported diff --git a/lit-test/py2llvm/typing/error_unify.py b/lit-test/py2llvm/typing/error_unify.py index 1140ba61f..8f0faf23a 100644 --- a/lit-test/py2llvm/typing/error_unify.py +++ b/lit-test/py2llvm/typing/error_unify.py @@ -16,3 +16,6 @@ a = b 1 and False # CHECK-L: note: an operand of type int(width='a) # CHECK-L: note: an operand of type bool + +# CHECK-L: ${LINE:+1}: error: expected operand to be of numeric type, not list(elt='a) +~[] diff --git a/lit-test/py2llvm/typing/unify.py b/lit-test/py2llvm/typing/unify.py index 8ea23dd0e..64dfb20b7 100644 --- a/lit-test/py2llvm/typing/unify.py +++ b/lit-test/py2llvm/typing/unify.py @@ -41,3 +41,9 @@ True and False 1 and 0 # CHECK-L: 1:int(width='g) and 0:int(width='g):int(width='g) + +~1 +# CHECK-L: 1:int(width='h):int(width='h) + +not 1 +# CHECK-L: 1:int(width='i):bool