diff --git a/artiq/py2llvm/asttyped.py b/artiq/py2llvm/asttyped.py index 415f7dcc0..e96edd641 100644 --- a/artiq/py2llvm/asttyped.py +++ b/artiq/py2llvm/asttyped.py @@ -22,6 +22,7 @@ class scoped(object): list of variables resolved as globals """ +# Typed versions of untyped nodes class argT(ast.arg, commontyped): pass @@ -82,3 +83,7 @@ class YieldT(ast.Yield, commontyped): pass class YieldFromT(ast.YieldFrom, commontyped): pass + +# Novel typed nodes +class CoerceT(ast.expr, commontyped): + _fields = ('expr',) # other_expr deliberately not in _fields diff --git a/artiq/py2llvm/builtins.py b/artiq/py2llvm/builtins.py index 84229cc6c..569b85cb2 100644 --- a/artiq/py2llvm/builtins.py +++ b/artiq/py2llvm/builtins.py @@ -23,36 +23,6 @@ class TFloat(types.TMono): def __init__(self): super().__init__("float") -class TTuple(types.Type): - """A tuple type.""" - - attributes = {} - - def __init__(self, elts=[]): - self.elts = elts - - def find(self): - return self - - def unify(self, other): - if isinstance(other, TTuple) and len(self.elts) == len(other.elts): - for selfelt, otherelt in zip(self.elts, other.elts): - selfelt.unify(otherelt) - elif isinstance(other, TVar): - other.unify(self) - else: - raise UnificationError(self, other) - - def __repr__(self): - return "TTuple(%s)" % (", ".join(map(repr, self.elts))) - - def __eq__(self, other): - return isinstance(other, TTuple) and \ - self.elts == other.elts - - def __ne__(self, other): - return not (self == other) - class TList(types.TMono): def __init__(self, elt=None): if elt is None: @@ -60,12 +30,37 @@ class TList(types.TMono): super().__init__("list", {"elt": elt}) +def is_none(typ): + return types.is_mono(typ, "NoneType") + +def is_bool(typ): + return types.is_mono(typ, "bool") + def is_int(typ, width=None): if width: return types.is_mono(typ, "int", {"width": width}) else: return types.is_mono(typ, "int") +def get_int_width(typ): + if is_int(typ): + return types.get_value(typ["width"]) + +def is_float(typ): + return types.is_mono(typ, "float") + def is_numeric(typ): + typ = typ.find() return isinstance(typ, types.TMono) and \ typ.name in ('int', 'float') + +def is_list(typ, elt=None): + if elt: + return types.is_mono(typ, "list", {"elt": elt}) + else: + return types.is_mono(typ, "list") + +def is_collection(typ): + typ = typ.find() + return isinstance(typ, types.TTuple) or \ + types.is_mono(typ, "list") diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index cce235ed8..97aff58ca 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -101,6 +101,36 @@ class TMono(Type): def __ne__(self, other): return not (self == other) +class TTuple(Type): + """A tuple type.""" + + attributes = {} + + def __init__(self, elts=[]): + self.elts = elts + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TTuple) and len(self.elts) == len(other.elts): + for selfelt, otherelt in zip(self.elts, other.elts): + selfelt.unify(otherelt) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def __repr__(self): + return "TTuple(%s)" % (", ".join(map(repr, self.elts))) + + def __eq__(self, other): + return isinstance(other, TTuple) and \ + self.elts == other.elts + + def __ne__(self, other): + return not (self == other) + class TValue(Type): """ A type-level value (such as the integer denoting width of @@ -131,15 +161,32 @@ class TValue(Type): def is_var(typ): - return isinstance(typ, TVar) + return isinstance(typ.find(), TVar) def is_mono(typ, name, **params): + typ = typ.find() params_match = True for param in params: params_match = params_match and typ.params[param] == params[param] return isinstance(typ, TMono) and \ typ.name == name and params_match +def is_tuple(typ, elts=None): + typ = typ.find() + if elts: + return isinstance(typ, TTuple) and \ + elts == typ.elts + else: + return isinstance(typ, TTuple) + +def get_value(typ): + typ = typ.find() + if isinstance(typ, TVar): + return None + elif isinstance(typ, TValue): + return typ.value + else: + assert False class TypePrinter(object): """ diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index a35b763f6..0f87dcd17 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -215,7 +215,7 @@ class ASTTypedRewriter(algorithm.Transformer): def visit_Tuple(self, node): node = self.generic_visit(node) - return asttyped.TupleT(type=builtins.TTuple([x.type for x in node.elts]), + return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]), elts=node.elts, ctx=node.ctx, loc=node.loc) def visit_List(self, node): @@ -282,7 +282,6 @@ class ASTTypedRewriter(algorithm.Transformer): self.engine.process(diag) # expr - visit_BinOp = visit_unsupported visit_Call = visit_unsupported visit_Compare = visit_unsupported visit_Dict = visit_unsupported @@ -375,16 +374,18 @@ class Inferencer(algorithm.Visitor): return makenotes def visit_ListT(self, node): + self.generic_visit(node) for elt in node.elts: self._unify(node.type["elt"], elt.type, node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element")) def visit_AttributeT(self, node): + self.generic_visit(node) object_type = node.value.type.find() if not types.is_var(object_type): if node.attr in object_type.attributes: # assumes no free type variables in .attributes - node.type = object_type.attributes[node.attr] + node.type.unify(object_type.attributes[node.attr]) # should never fail else: diag = diagnostic.Diagnostic("error", "type {type} does not have an attribute '{attr}'", @@ -393,48 +394,221 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) def visit_SubscriptT(self, node): + self.generic_visit(node) # TODO: support more than just lists self._unify(builtins.TList(node.type), node.value.type, node.loc, node.value.loc) def visit_IfExpT(self, node): + self.generic_visit(node) self._unify(node.body.type, node.orelse.type, node.body.loc, node.orelse.loc) - node.type = node.body.type + node.type.unify(node.body.type) # should never fail def visit_BoolOpT(self, node): + self.generic_visit(node) for value in node.values: self._unify(node.type, value.type, node.loc, value.loc, self._makenotes_elts(node.values, "an operand")) def visit_UnaryOpT(self, node): + self.generic_visit(node) operand_type = node.operand.type.find() if isinstance(node.op, ast.Not): - node.type = builtins.TBool() + node.type.unify(builtins.TBool()) # should never fail elif isinstance(node.op, ast.Invert): if builtins.is_int(operand_type): - node.type = operand_type + node.type.unify(operand_type) # should never fail elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", - "expected ~ operand to be of integer type, not {type}", + "expected '~' operand to be of integer type, not {type}", {"type": types.TypePrinter().name(operand_type)}, node.operand.loc) self.engine.process(diag) else: # UAdd, USub if builtins.is_numeric(operand_type): - node.type = operand_type + node.type.unify(operand_type) # should never fail elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", - "expected unary {op} operand to be of numeric type, not {type}", + "expected unary '{op}' operand to be of numeric type, not {type}", {"op": node.op.loc.source(), "type": types.TypePrinter().name(operand_type)}, node.operand.loc) self.engine.process(diag) + def visit_CoerceT(self, node): + self.generic_visit(node) + if builtins.is_numeric(node.type) and builtins.is_numeric(node.expr.type): + pass + else: + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "expression that required coercion to {typeb}", + {"typeb": printer.name(node.type)}, + node.other_expr.loc) + diag = diagnostic.Diagnostic("error", + "cannot coerce {typea} to {typeb}", + {"typea": printer.name(node.expr.type), "typeb": printer.name(node.type)}, + node.loc, notes=[note]) + self.engine.process(diag) + + def _coerce_one(self, typ, coerced_node, other_node): + if coerced_node.type.find() == typ.find(): + return coerced_node + else: + node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node, + loc=coerced_node.loc) + self.visit(node) + return node + + def _coerce_numeric(self, return_type, left, right): + # Implements the coercion protocol. + # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. + if builtins.is_float(left.type) or builtins.is_float(right.type): + typ = builtins.TFloat() + elif builtins.is_int(left.type) or builtins.is_int(right.type): + left_width, right_width = \ + builtins.get_int_width(left.type), builtins.get_int_width(left.type) + if left_width and right_width: + typ = builtins.TInt(types.TValue(max(left_width, right_width))) + else: + typ = builtins.TInt() + elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet + return left, right + else: # conflicting types + printer = types.TypePrinter() + note1 = diagnostic.Diagnostic("note", + "expression of type {typea}", {"typea": printer.name(left.type)}, + left.loc) + note2 = diagnostic.Diagnostic("note", + "expression of type {typeb}", {"typeb": printer.name(right.type)}, + right.loc) + diag = diagnostic.Diagnostic("error", + "cannot coerce {typea} and {typeb} to a common numeric type", + {"typea": printer.name(left.type), "typeb": printer.name(right.type)}, + left.loc, [right.loc], + [note1, note2]) + self.engine.process(diag) + return left, right + + # On 1st invocation, return_type is always a type variable. + # On further invocations, coerce will only ever refine the type, + # so this should never fail. + return_type.unify(typ) + + return self._coerce_one(typ, left, other_node=right), \ + self._coerce_one(typ, right, other_node=left) + + def _order_by_pred(self, pred, left, right): + if pred(left.type): + return left, right + elif pred(right.type): + return right, left + else: + assert False + + def visit_BinOpT(self, node): + self.generic_visit(node) + if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor, + ast.LShift, ast.RShift)): + # bitwise operators require integers + for operand in (node.left, node.right): + if not types.is_var(operand.type) and not builtins.is_int(operand.type): + diag = diagnostic.Diagnostic("error", + "expected '{op}' operand to be of integer type, not {type}", + {"op": node.op.loc.source(), + "type": types.TypePrinter().name(operand.type)}, + node.op.loc, [operand.loc]) + self.engine.process(diag) + return + + node.left, node.right = \ + self._coerce_numeric(node.type, node.left, node.right) + elif isinstance(node.op, ast.Add): + # add works on numbers and also collections + if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type): + collection, other = \ + self._order_by_pred(builtins.is_collection, node.left, node.right) + if types.is_tuple(collection.type): + pred, kind = types.is_tuple, "tuple" + elif builtins.is_list(collection.type): + pred, kind = builtins.is_list, "list" + else: + assert False + if not pred(other.type): + printer = types.TypePrinter() + note1 = diagnostic.Diagnostic("note", + "{kind} of type {typea}", + {"typea": printer.name(collection.type), "kind": kind}, + collection.loc) + note2 = diagnostic.Diagnostic("note", + "{typeb}, which cannot be added to a {kind}", + {"typeb": printer.name(other.type), "kind": kind}, + other.loc) + diag = diagnostic.Diagnostic("error", + "expected every '+' operand to be a {kind} in this context", + {"kind": kind}, + node.op.loc, [other.loc, collection.loc], + [note1, note2]) + self.engine.process(diag) + return + + if types.is_tuple(collection.type): + # should never fail + node.type.unify(types.TTuple(node.left.type.find().elts + + node.right.type.find().elts)) + elif builtins.is_list(collection.type): + self._unify(node.left.type, node.right.type, + node.left.loc, node.right.loc) + node.type.unify(node.left.type) # should never fail + else: + node.left, node.right = \ + self._coerce_numeric(node.type, node.left, node.right) + elif isinstance(node.op, ast.Mult): + # mult works on numbers and also number & collection + if types.is_tuple(node.left.type) or types.is_tuple(node.right.type): + tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right) + diag = diagnostic.Diagnostic("error", + "py2llvm does not support passing tuples to '*'", {}, + node.op.loc, [tuple_.loc]) + self.engine.process(diag) + elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type): + list_, other = self._order_by_pred(builtins.is_list, node.left, node.right) + if not builtins.is_int(other.type): + printer = types.TypePrinter() + note1 = diagnostic.Diagnostic("note", + "list operand of type {typea}", + {"typea": printer.name(list_.type)}, + list_.loc) + note2 = diagnostic.Diagnostic("note", + "operand of type {typeb}, which is not a valid repetition amount", + {"typeb": printer.name(other.type)}, + other.loc) + diag = diagnostic.Diagnostic("error", + "expected '*' operands to be a list and an integer in this context", {}, + node.op.loc, [list_.loc, other.loc], + [note1, note2]) + self.engine.process(diag) + return + node.type.unify(list_.type) + else: + node.left, node.right = \ + self._coerce_numeric(node.type, node.left, node.right) + elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)): + # numeric operators work on any kind of number + node.left, node.right = \ + self._coerce_numeric(node.type, node.left, node.right) + else: # MatMult + diag = diagnostic.Diagnostic("error", + "operator '{op}' is not supported", {"op": node.op.loc.source()}, + node.op.loc) + self.engine.process(diag) + return + def visit_Assign(self, node): self.generic_visit(node) if len(node.targets) > 1: - self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type, + self._unify(types.TTuple([x.type for x in node.targets]), node.value.type, node.targets[0].loc.join(node.targets[-1].loc), node.value.loc) else: self._unify(node.targets[0].type, node.value.type, diff --git a/lit-test/py2llvm/typing/coerce.py b/lit-test/py2llvm/typing/coerce.py new file mode 100644 index 000000000..ac33002d9 --- /dev/null +++ b/lit-test/py2llvm/typing/coerce.py @@ -0,0 +1,26 @@ +# RUN: %python -m artiq.py2llvm.typing %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +1 | 2 +# CHECK-L: 1:int(width='a):int(width='b) | 2:int(width='c):int(width='b):int(width='b) + +1 + 2 +# CHECK-L: 1:int(width='d):int(width='e) + 2:int(width='f):int(width='e):int(width='e) + +(1,) + (2.0,) +# CHECK-L: (1:int(width='g),):(int(width='g),) + (2.0:float,):(float,):(int(width='g), float) + +[1] + [2] +# CHECK-L: [1:int(width='h)]:list(elt=int(width='h)) + [2:int(width='h)]:list(elt=int(width='h)):list(elt=int(width='h)) + +1 * 2 +# CHECK-L: 1:int(width='i):int(width='j) * 2:int(width='k):int(width='j):int(width='j) + +[1] * 2 +# CHECK-L: [1:int(width='l)]:list(elt=int(width='l)) * 2:int(width='m):list(elt=int(width='l)) + +1 / 2 +# CHECK-L: 1:int(width='n):int(width='o) / 2:int(width='p):int(width='o):int(width='o) + +1 + 1.0 +# CHECK-L: 1:int(width='q):float + 1.0:float:float diff --git a/lit-test/py2llvm/typing/error_coerce.py b/lit-test/py2llvm/typing/error_coerce.py new file mode 100644 index 000000000..acd90a48c --- /dev/null +++ b/lit-test/py2llvm/typing/error_coerce.py @@ -0,0 +1,35 @@ +# RUN: %python -m artiq.py2llvm.typing +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: expected '<<' operand to be of integer type, not float +1 << 2.0 + +# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a list in this context +# CHECK-L: ${LINE:+2}: note: list of type list(elt=int(width='a)) +# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a list +[1] + 2 + +# CHECK-L: ${LINE:+1}: error: cannot unify list(elt=int(width='a)) with list(elt=float): int(width='a) is incompatible with float +[1] + [2.0] + +# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a tuple in this context +# CHECK-L: ${LINE:+2}: note: tuple of type (int(width='a),) +# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a tuple +(1,) + 2 + +# CHECK-L: ${LINE:+1}: error: py2llvm does not support passing tuples to '*' +(1,) * 2 + +# CHECK-L: ${LINE:+3}: error: expected '*' operands to be a list and an integer in this context +# CHECK-L: ${LINE:+2}: note: list operand of type list(elt=int(width='a)) +# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount +[1] * [] + +# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type +# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a) +# CHECK-L: ${LINE:+1}: note: expression of type NoneType +[] - None + +# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float +# CHECK-L: ${LINE:+1}: note: expression that required coercion to float +[] - 1.0 diff --git a/lit-test/py2llvm/typing/error_unify.py b/lit-test/py2llvm/typing/error_unify.py index d268837e1..3abef5b13 100644 --- a/lit-test/py2llvm/typing/error_unify.py +++ b/lit-test/py2llvm/typing/error_unify.py @@ -17,10 +17,10 @@ a = b # CHECK-L: note: an operand of type int(width='a) # CHECK-L: note: an operand of type bool -# CHECK-L: ${LINE:+1}: error: expected unary + operand to be of numeric type, not list(elt='a) +# CHECK-L: ${LINE:+1}: error: expected unary '+' operand to be of numeric type, not list(elt='a) +[] -# CHECK-L: ${LINE:+1}: error: expected ~ operand to be of integer type, not float +# CHECK-L: ${LINE:+1}: error: expected '~' operand to be of integer type, not float ~1.0 # CHECK-L: ${LINE:+1}: error: type int(width='a) does not have an attribute 'x'