forked from M-Labs/artiq
Add support for Compare.
This commit is contained in:
parent
fe69c5b465
commit
20b7a73b49
|
@ -163,13 +163,14 @@ class TValue(Type):
|
||||||
def is_var(typ):
|
def is_var(typ):
|
||||||
return isinstance(typ.find(), TVar)
|
return isinstance(typ.find(), TVar)
|
||||||
|
|
||||||
def is_mono(typ, name, **params):
|
def is_mono(typ, name=None, **params):
|
||||||
typ = typ.find()
|
typ = typ.find()
|
||||||
params_match = True
|
params_match = True
|
||||||
for param in params:
|
for param in params:
|
||||||
params_match = params_match and typ.params[param] == params[param]
|
params_match = params_match and \
|
||||||
|
typ.params[param].find() == params[param].find()
|
||||||
return isinstance(typ, TMono) and \
|
return isinstance(typ, TMono) and \
|
||||||
typ.name == name and params_match
|
(name is None or (typ.name == name and params_match))
|
||||||
|
|
||||||
def is_tuple(typ, elts=None):
|
def is_tuple(typ, elts=None):
|
||||||
typ = typ.find()
|
typ = typ.find()
|
||||||
|
|
|
@ -283,7 +283,6 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
|
|
||||||
# expr
|
# expr
|
||||||
visit_Call = visit_unsupported
|
visit_Call = visit_unsupported
|
||||||
visit_Compare = visit_unsupported
|
|
||||||
visit_Dict = visit_unsupported
|
visit_Dict = visit_unsupported
|
||||||
visit_DictComp = visit_unsupported
|
visit_DictComp = visit_unsupported
|
||||||
visit_Ellipsis = visit_unsupported
|
visit_Ellipsis = visit_unsupported
|
||||||
|
@ -393,11 +392,14 @@ class Inferencer(algorithm.Visitor):
|
||||||
node.attr_loc, [node.value.loc])
|
node.attr_loc, [node.value.loc])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
def _unify_collection(self, element, collection):
|
||||||
|
# TODO: support more than just lists
|
||||||
|
self._unify(builtins.TList(element.type), collection.type,
|
||||||
|
element.loc, collection.loc)
|
||||||
|
|
||||||
def visit_SubscriptT(self, node):
|
def visit_SubscriptT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
# TODO: support more than just lists
|
self._unify_collection(element=node, collection=node.value)
|
||||||
self._unify(builtins.TList(node.type), node.value.type,
|
|
||||||
node.loc, node.value.loc)
|
|
||||||
|
|
||||||
def visit_IfExpT(self, node):
|
def visit_IfExpT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
@ -455,43 +457,39 @@ class Inferencer(algorithm.Visitor):
|
||||||
def _coerce_one(self, typ, coerced_node, other_node):
|
def _coerce_one(self, typ, coerced_node, other_node):
|
||||||
if coerced_node.type.find() == typ.find():
|
if coerced_node.type.find() == typ.find():
|
||||||
return coerced_node
|
return coerced_node
|
||||||
|
elif isinstance(coerced_node, asttyped.CoerceT):
|
||||||
|
node.type, node.other_expr = typ, other_node
|
||||||
else:
|
else:
|
||||||
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
|
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
|
||||||
loc=coerced_node.loc)
|
loc=coerced_node.loc)
|
||||||
self.visit(node)
|
self.visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _coerce_numeric(self, left, right):
|
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
|
||||||
# Implements the coercion protocol.
|
|
||||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||||
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
node_types = [node.type for node in nodes]
|
||||||
typ = builtins.TFloat()
|
if any(map(types.is_var, node_types)): # not enough info yet
|
||||||
elif builtins.is_int(left.type) or builtins.is_int(right.type):
|
|
||||||
left_width, right_width = \
|
|
||||||
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
|
|
||||||
if left_width and right_width:
|
|
||||||
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
|
|
||||||
else:
|
|
||||||
typ = builtins.TInt()
|
|
||||||
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
|
||||||
return
|
return
|
||||||
else: # conflicting types
|
elif not all(map(builtins.is_numeric, node_types)):
|
||||||
printer = types.TypePrinter()
|
err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
|
||||||
note1 = diagnostic.Diagnostic("note",
|
|
||||||
"expression of type {typea}", {"typea": printer.name(left.type)},
|
|
||||||
left.loc)
|
|
||||||
note2 = diagnostic.Diagnostic("note",
|
|
||||||
"expression of type {typeb}", {"typeb": printer.name(right.type)},
|
|
||||||
right.loc)
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"cannot coerce {typea} and {typeb} to a common numeric type",
|
"cannot coerce {type} to a numeric type",
|
||||||
{"typea": printer.name(left.type), "typeb": printer.name(right.type)},
|
{"type": types.TypePrinter().name(err_node.type)},
|
||||||
left.loc, [right.loc],
|
err_node.loc, [])
|
||||||
[note1, note2])
|
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
elif any(map(builtins.is_float, node_types)):
|
||||||
|
typ = builtins.TFloat()
|
||||||
|
elif any(map(builtins.is_int, node_types)):
|
||||||
|
widths = map(builtins.get_int_width, node_types)
|
||||||
|
if all(widths):
|
||||||
|
typ = builtins.TInt(types.TValue(max(widths)))
|
||||||
|
else:
|
||||||
|
typ = builtins.TInt()
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
return typ, typ, typ
|
return map_return(typ)
|
||||||
|
|
||||||
def _order_by_pred(self, pred, left, right):
|
def _order_by_pred(self, pred, left, right):
|
||||||
if pred(left.type):
|
if pred(left.type):
|
||||||
|
@ -503,7 +501,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
def _coerce_binop(self, op, left, right):
|
def _coerce_binop(self, op, left, right):
|
||||||
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||||
ast.LShift, ast.RShift)):
|
ast.LShift, ast.RShift)):
|
||||||
# bitwise operators require integers
|
# bitwise operators require integers
|
||||||
for operand in (left, right):
|
for operand in (left, right):
|
||||||
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
||||||
|
@ -515,7 +513,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
|
||||||
return self._coerce_numeric(left, right)
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||||
elif isinstance(op, ast.Add):
|
elif isinstance(op, ast.Add):
|
||||||
# add works on numbers and also collections
|
# add works on numbers and also collections
|
||||||
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
||||||
|
@ -554,7 +552,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
left.loc, right.loc)
|
left.loc, right.loc)
|
||||||
return left.type, left.type, right.type
|
return left.type, left.type, right.type
|
||||||
else:
|
else:
|
||||||
return self._coerce_numeric(left, right)
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||||
elif isinstance(op, ast.Mult):
|
elif isinstance(op, ast.Mult):
|
||||||
# mult works on numbers and also number & collection
|
# mult works on numbers and also number & collection
|
||||||
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
||||||
|
@ -585,10 +583,10 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
return list_.type, left.type, right.type
|
return list_.type, left.type, right.type
|
||||||
else:
|
else:
|
||||||
return self._coerce_numeric(left, right)
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||||
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||||
# numeric operators work on any kind of number
|
# numeric operators work on any kind of number
|
||||||
return self._coerce_numeric(left, right)
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||||
else: # MatMult
|
else: # MatMult
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"operator '{op}' is not supported", {"op": op.loc.source()},
|
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||||
|
@ -605,6 +603,42 @@ class Inferencer(algorithm.Visitor):
|
||||||
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
||||||
node.type.unify(return_type) # should never fail
|
node.type.unify(return_type) # should never fail
|
||||||
|
|
||||||
|
def visit_CompareT(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
pairs = zip([node.left] + node.comparators, node.comparators)
|
||||||
|
if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)):
|
||||||
|
for left, right in pairs:
|
||||||
|
self._unify(left.type, right.type,
|
||||||
|
left.loc, right.loc)
|
||||||
|
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
|
||||||
|
for left, right in pairs:
|
||||||
|
self._unify_collection(element=left, collection=right)
|
||||||
|
else: # Eq, NotEq, Lt, LtE, Gt, GtE
|
||||||
|
operands = [node.left] + node.comparators
|
||||||
|
operand_types = [operand.type for operand in operands]
|
||||||
|
if any(map(builtins.is_collection, operand_types)):
|
||||||
|
for left, right in pairs:
|
||||||
|
self._unify(left.type, right.type,
|
||||||
|
left.loc, right.loc)
|
||||||
|
else:
|
||||||
|
typ = self._coerce_numeric(operands)
|
||||||
|
if typ:
|
||||||
|
try:
|
||||||
|
other_node = next(filter(lambda operand: operand.type.find() == typ.find(),
|
||||||
|
operands))
|
||||||
|
except StopIteration:
|
||||||
|
# can't find an argument with an exact type, meaning
|
||||||
|
# the return value is more generic than any of the inputs, meaning
|
||||||
|
# the type is known (typ is not None), but its width is not
|
||||||
|
def wide_enough(opreand):
|
||||||
|
return types.is_mono(opreand.type) and \
|
||||||
|
opreand.type.find().name == typ.find().name
|
||||||
|
other_node = next(filter(wide_enough, operands))
|
||||||
|
print(typ, other_node)
|
||||||
|
node.left, *node.comparators = \
|
||||||
|
[self._coerce_one(typ, operand, other_node) for operand in operands]
|
||||||
|
node.type.unify(builtins.TBool())
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if len(node.targets) > 1:
|
if len(node.targets) > 1:
|
||||||
|
|
|
@ -27,3 +27,15 @@
|
||||||
|
|
||||||
a = []; a += [1]
|
a = []; a += [1]
|
||||||
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
|
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
|
||||||
|
|
||||||
|
[] is [1]
|
||||||
|
# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool
|
||||||
|
|
||||||
|
1 in [1]
|
||||||
|
# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool
|
||||||
|
|
||||||
|
[] < [1]
|
||||||
|
# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool
|
||||||
|
|
||||||
|
1.0 < 1
|
||||||
|
# CHECK-L: 1.0:float < 1:int(width='v):float:bool
|
||||||
|
|
|
@ -25,13 +25,7 @@
|
||||||
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
|
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
|
||||||
[1] * []
|
[1] * []
|
||||||
|
|
||||||
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type
|
# CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type
|
||||||
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
|
|
||||||
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
|
|
||||||
[] - None
|
|
||||||
|
|
||||||
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
|
||||||
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
|
||||||
[] - 1.0
|
[] - 1.0
|
||||||
|
|
||||||
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
|
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
|
||||||
|
|
Loading…
Reference in New Issue