diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index 97aff58ca..19ad56d48 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -163,13 +163,14 @@ class TValue(Type): def is_var(typ): return isinstance(typ.find(), TVar) -def is_mono(typ, name, **params): +def is_mono(typ, name=None, **params): typ = typ.find() params_match = True for param in params: - params_match = params_match and typ.params[param] == params[param] + params_match = params_match and \ + typ.params[param].find() == params[param].find() return isinstance(typ, TMono) and \ - typ.name == name and params_match + (name is None or (typ.name == name and params_match)) def is_tuple(typ, elts=None): typ = typ.find() diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index ec4991eb7..875920df3 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -283,7 +283,6 @@ class ASTTypedRewriter(algorithm.Transformer): # expr visit_Call = visit_unsupported - visit_Compare = visit_unsupported visit_Dict = visit_unsupported visit_DictComp = visit_unsupported visit_Ellipsis = visit_unsupported @@ -393,11 +392,14 @@ class Inferencer(algorithm.Visitor): node.attr_loc, [node.value.loc]) self.engine.process(diag) + def _unify_collection(self, element, collection): + # TODO: support more than just lists + self._unify(builtins.TList(element.type), collection.type, + element.loc, collection.loc) + def visit_SubscriptT(self, node): self.generic_visit(node) - # TODO: support more than just lists - self._unify(builtins.TList(node.type), node.value.type, - node.loc, node.value.loc) + self._unify_collection(element=node, collection=node.value) def visit_IfExpT(self, node): self.generic_visit(node) @@ -455,43 +457,39 @@ class Inferencer(algorithm.Visitor): def _coerce_one(self, typ, coerced_node, other_node): if coerced_node.type.find() == typ.find(): return coerced_node + elif isinstance(coerced_node, asttyped.CoerceT): + node.type, node.other_expr = typ, other_node else: node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node, loc=coerced_node.loc) - self.visit(node) - return node + self.visit(node) + return node - def _coerce_numeric(self, left, right): - # Implements the coercion protocol. + def _coerce_numeric(self, nodes, map_return=lambda typ: typ): # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. - if builtins.is_float(left.type) or builtins.is_float(right.type): - typ = builtins.TFloat() - elif builtins.is_int(left.type) or builtins.is_int(right.type): - left_width, right_width = \ - builtins.get_int_width(left.type), builtins.get_int_width(left.type) - if left_width and right_width: - typ = builtins.TInt(types.TValue(max(left_width, right_width))) - else: - typ = builtins.TInt() - elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet + node_types = [node.type for node in nodes] + if any(map(types.is_var, node_types)): # not enough info yet return - else: # conflicting types - printer = types.TypePrinter() - note1 = diagnostic.Diagnostic("note", - "expression of type {typea}", {"typea": printer.name(left.type)}, - left.loc) - note2 = diagnostic.Diagnostic("note", - "expression of type {typeb}", {"typeb": printer.name(right.type)}, - right.loc) + elif not all(map(builtins.is_numeric, node_types)): + err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes)) diag = diagnostic.Diagnostic("error", - "cannot coerce {typea} and {typeb} to a common numeric type", - {"typea": printer.name(left.type), "typeb": printer.name(right.type)}, - left.loc, [right.loc], - [note1, note2]) + "cannot coerce {type} to a numeric type", + {"type": types.TypePrinter().name(err_node.type)}, + err_node.loc, []) self.engine.process(diag) return + elif any(map(builtins.is_float, node_types)): + typ = builtins.TFloat() + elif any(map(builtins.is_int, node_types)): + widths = map(builtins.get_int_width, node_types) + if all(widths): + typ = builtins.TInt(types.TValue(max(widths))) + else: + typ = builtins.TInt() + else: + assert False - return typ, typ, typ + return map_return(typ) def _order_by_pred(self, pred, left, right): if pred(left.type): @@ -503,7 +501,7 @@ class Inferencer(algorithm.Visitor): def _coerce_binop(self, op, left, right): if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, - ast.LShift, ast.RShift)): + ast.LShift, ast.RShift)): # bitwise operators require integers for operand in (left, right): if not types.is_var(operand.type) and not builtins.is_int(operand.type): @@ -515,7 +513,7 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) return - return self._coerce_numeric(left, right) + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, ast.Add): # add works on numbers and also collections if builtins.is_collection(left.type) or builtins.is_collection(right.type): @@ -554,7 +552,7 @@ class Inferencer(algorithm.Visitor): left.loc, right.loc) return left.type, left.type, right.type else: - return self._coerce_numeric(left, right) + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, ast.Mult): # mult works on numbers and also number & collection if types.is_tuple(left.type) or types.is_tuple(right.type): @@ -585,10 +583,10 @@ class Inferencer(algorithm.Visitor): return list_.type, left.type, right.type else: - return self._coerce_numeric(left, right) + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)): # numeric operators work on any kind of number - return self._coerce_numeric(left, right) + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) else: # MatMult diag = diagnostic.Diagnostic("error", "operator '{op}' is not supported", {"op": op.loc.source()}, @@ -605,6 +603,42 @@ class Inferencer(algorithm.Visitor): node.right = self._coerce_one(right_type, node.right, other_node=node.left) node.type.unify(return_type) # should never fail + def visit_CompareT(self, node): + self.generic_visit(node) + pairs = zip([node.left] + node.comparators, node.comparators) + if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)): + for left, right in pairs: + self._unify(left.type, right.type, + left.loc, right.loc) + elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)): + for left, right in pairs: + self._unify_collection(element=left, collection=right) + else: # Eq, NotEq, Lt, LtE, Gt, GtE + operands = [node.left] + node.comparators + operand_types = [operand.type for operand in operands] + if any(map(builtins.is_collection, operand_types)): + for left, right in pairs: + self._unify(left.type, right.type, + left.loc, right.loc) + else: + typ = self._coerce_numeric(operands) + if typ: + try: + other_node = next(filter(lambda operand: operand.type.find() == typ.find(), + operands)) + except StopIteration: + # can't find an argument with an exact type, meaning + # the return value is more generic than any of the inputs, meaning + # the type is known (typ is not None), but its width is not + def wide_enough(opreand): + return types.is_mono(opreand.type) and \ + opreand.type.find().name == typ.find().name + other_node = next(filter(wide_enough, operands)) + print(typ, other_node) + node.left, *node.comparators = \ + [self._coerce_one(typ, operand, other_node) for operand in operands] + node.type.unify(builtins.TBool()) + def visit_Assign(self, node): self.generic_visit(node) if len(node.targets) > 1: diff --git a/lit-test/py2llvm/typing/coerce.py b/lit-test/py2llvm/typing/coerce.py index 786f3e81a..34df20e7e 100644 --- a/lit-test/py2llvm/typing/coerce.py +++ b/lit-test/py2llvm/typing/coerce.py @@ -27,3 +27,15 @@ a = []; a += [1] # CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r)) + +[] is [1] +# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool + +1 in [1] +# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool + +[] < [1] +# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool + +1.0 < 1 +# CHECK-L: 1.0:float < 1:int(width='v):float:bool diff --git a/lit-test/py2llvm/typing/error_coerce.py b/lit-test/py2llvm/typing/error_coerce.py index ad8416577..0ac65724f 100644 --- a/lit-test/py2llvm/typing/error_coerce.py +++ b/lit-test/py2llvm/typing/error_coerce.py @@ -25,13 +25,7 @@ # CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount [1] * [] -# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type -# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a) -# CHECK-L: ${LINE:+1}: note: expression of type NoneType -[] - None - -# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float -# CHECK-L: ${LINE:+1}: note: expression that required coercion to float +# CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type [] - 1.0 # CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid