forked from M-Labs/artiq
1
0
Fork 0

Add support for Compare.

This commit is contained in:
whitequark 2015-06-14 22:48:04 +03:00
parent fe69c5b465
commit 20b7a73b49
4 changed files with 87 additions and 46 deletions

View File

@ -163,13 +163,14 @@ class TValue(Type):
def is_var(typ): def is_var(typ):
return isinstance(typ.find(), TVar) return isinstance(typ.find(), TVar)
def is_mono(typ, name, **params): def is_mono(typ, name=None, **params):
typ = typ.find() typ = typ.find()
params_match = True params_match = True
for param in params: for param in params:
params_match = params_match and typ.params[param] == params[param] params_match = params_match and \
typ.params[param].find() == params[param].find()
return isinstance(typ, TMono) and \ return isinstance(typ, TMono) and \
typ.name == name and params_match (name is None or (typ.name == name and params_match))
def is_tuple(typ, elts=None): def is_tuple(typ, elts=None):
typ = typ.find() typ = typ.find()

View File

@ -283,7 +283,6 @@ class ASTTypedRewriter(algorithm.Transformer):
# expr # expr
visit_Call = visit_unsupported visit_Call = visit_unsupported
visit_Compare = visit_unsupported
visit_Dict = visit_unsupported visit_Dict = visit_unsupported
visit_DictComp = visit_unsupported visit_DictComp = visit_unsupported
visit_Ellipsis = visit_unsupported visit_Ellipsis = visit_unsupported
@ -393,11 +392,14 @@ class Inferencer(algorithm.Visitor):
node.attr_loc, [node.value.loc]) node.attr_loc, [node.value.loc])
self.engine.process(diag) self.engine.process(diag)
def _unify_collection(self, element, collection):
# TODO: support more than just lists
self._unify(builtins.TList(element.type), collection.type,
element.loc, collection.loc)
def visit_SubscriptT(self, node): def visit_SubscriptT(self, node):
self.generic_visit(node) self.generic_visit(node)
# TODO: support more than just lists self._unify_collection(element=node, collection=node.value)
self._unify(builtins.TList(node.type), node.value.type,
node.loc, node.value.loc)
def visit_IfExpT(self, node): def visit_IfExpT(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -455,43 +457,39 @@ class Inferencer(algorithm.Visitor):
def _coerce_one(self, typ, coerced_node, other_node): def _coerce_one(self, typ, coerced_node, other_node):
if coerced_node.type.find() == typ.find(): if coerced_node.type.find() == typ.find():
return coerced_node return coerced_node
elif isinstance(coerced_node, asttyped.CoerceT):
node.type, node.other_expr = typ, other_node
else: else:
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node, node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
loc=coerced_node.loc) loc=coerced_node.loc)
self.visit(node) self.visit(node)
return node return node
def _coerce_numeric(self, left, right): def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
# Implements the coercion protocol.
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
if builtins.is_float(left.type) or builtins.is_float(right.type): node_types = [node.type for node in nodes]
typ = builtins.TFloat() if any(map(types.is_var, node_types)): # not enough info yet
elif builtins.is_int(left.type) or builtins.is_int(right.type):
left_width, right_width = \
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
if left_width and right_width:
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
else:
typ = builtins.TInt()
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
return return
else: # conflicting types elif not all(map(builtins.is_numeric, node_types)):
printer = types.TypePrinter() err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
note1 = diagnostic.Diagnostic("note",
"expression of type {typea}", {"typea": printer.name(left.type)},
left.loc)
note2 = diagnostic.Diagnostic("note",
"expression of type {typeb}", {"typeb": printer.name(right.type)},
right.loc)
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"cannot coerce {typea} and {typeb} to a common numeric type", "cannot coerce {type} to a numeric type",
{"typea": printer.name(left.type), "typeb": printer.name(right.type)}, {"type": types.TypePrinter().name(err_node.type)},
left.loc, [right.loc], err_node.loc, [])
[note1, note2])
self.engine.process(diag) self.engine.process(diag)
return return
elif any(map(builtins.is_float, node_types)):
typ = builtins.TFloat()
elif any(map(builtins.is_int, node_types)):
widths = map(builtins.get_int_width, node_types)
if all(widths):
typ = builtins.TInt(types.TValue(max(widths)))
else:
typ = builtins.TInt()
else:
assert False
return typ, typ, typ return map_return(typ)
def _order_by_pred(self, pred, left, right): def _order_by_pred(self, pred, left, right):
if pred(left.type): if pred(left.type):
@ -503,7 +501,7 @@ class Inferencer(algorithm.Visitor):
def _coerce_binop(self, op, left, right): def _coerce_binop(self, op, left, right):
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
ast.LShift, ast.RShift)): ast.LShift, ast.RShift)):
# bitwise operators require integers # bitwise operators require integers
for operand in (left, right): for operand in (left, right):
if not types.is_var(operand.type) and not builtins.is_int(operand.type): if not types.is_var(operand.type) and not builtins.is_int(operand.type):
@ -515,7 +513,7 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
return return
return self._coerce_numeric(left, right) return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Add): elif isinstance(op, ast.Add):
# add works on numbers and also collections # add works on numbers and also collections
if builtins.is_collection(left.type) or builtins.is_collection(right.type): if builtins.is_collection(left.type) or builtins.is_collection(right.type):
@ -554,7 +552,7 @@ class Inferencer(algorithm.Visitor):
left.loc, right.loc) left.loc, right.loc)
return left.type, left.type, right.type return left.type, left.type, right.type
else: else:
return self._coerce_numeric(left, right) return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Mult): elif isinstance(op, ast.Mult):
# mult works on numbers and also number & collection # mult works on numbers and also number & collection
if types.is_tuple(left.type) or types.is_tuple(right.type): if types.is_tuple(left.type) or types.is_tuple(right.type):
@ -585,10 +583,10 @@ class Inferencer(algorithm.Visitor):
return list_.type, left.type, right.type return list_.type, left.type, right.type
else: else:
return self._coerce_numeric(left, right) return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)): elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
# numeric operators work on any kind of number # numeric operators work on any kind of number
return self._coerce_numeric(left, right) return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
else: # MatMult else: # MatMult
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": op.loc.source()}, "operator '{op}' is not supported", {"op": op.loc.source()},
@ -605,6 +603,42 @@ class Inferencer(algorithm.Visitor):
node.right = self._coerce_one(right_type, node.right, other_node=node.left) node.right = self._coerce_one(right_type, node.right, other_node=node.left)
node.type.unify(return_type) # should never fail node.type.unify(return_type) # should never fail
def visit_CompareT(self, node):
self.generic_visit(node)
pairs = zip([node.left] + node.comparators, node.comparators)
if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)):
for left, right in pairs:
self._unify(left.type, right.type,
left.loc, right.loc)
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
for left, right in pairs:
self._unify_collection(element=left, collection=right)
else: # Eq, NotEq, Lt, LtE, Gt, GtE
operands = [node.left] + node.comparators
operand_types = [operand.type for operand in operands]
if any(map(builtins.is_collection, operand_types)):
for left, right in pairs:
self._unify(left.type, right.type,
left.loc, right.loc)
else:
typ = self._coerce_numeric(operands)
if typ:
try:
other_node = next(filter(lambda operand: operand.type.find() == typ.find(),
operands))
except StopIteration:
# can't find an argument with an exact type, meaning
# the return value is more generic than any of the inputs, meaning
# the type is known (typ is not None), but its width is not
def wide_enough(opreand):
return types.is_mono(opreand.type) and \
opreand.type.find().name == typ.find().name
other_node = next(filter(wide_enough, operands))
print(typ, other_node)
node.left, *node.comparators = \
[self._coerce_one(typ, operand, other_node) for operand in operands]
node.type.unify(builtins.TBool())
def visit_Assign(self, node): def visit_Assign(self, node):
self.generic_visit(node) self.generic_visit(node)
if len(node.targets) > 1: if len(node.targets) > 1:

View File

@ -27,3 +27,15 @@
a = []; a += [1] a = []; a += [1]
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r)) # CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
[] is [1]
# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool
1 in [1]
# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool
[] < [1]
# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool
1.0 < 1
# CHECK-L: 1.0:float < 1:int(width='v):float:bool

View File

@ -25,13 +25,7 @@
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount # CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
[1] * [] [1] * []
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type # CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
[] - None
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
[] - 1.0 [] - 1.0
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid # CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid