forked from M-Labs/artiq
Add support for Compare.
This commit is contained in:
parent
fe69c5b465
commit
20b7a73b49
@ -163,13 +163,14 @@ class TValue(Type):
|
||||
def is_var(typ):
|
||||
return isinstance(typ.find(), TVar)
|
||||
|
||||
def is_mono(typ, name, **params):
|
||||
def is_mono(typ, name=None, **params):
|
||||
typ = typ.find()
|
||||
params_match = True
|
||||
for param in params:
|
||||
params_match = params_match and typ.params[param] == params[param]
|
||||
params_match = params_match and \
|
||||
typ.params[param].find() == params[param].find()
|
||||
return isinstance(typ, TMono) and \
|
||||
typ.name == name and params_match
|
||||
(name is None or (typ.name == name and params_match))
|
||||
|
||||
def is_tuple(typ, elts=None):
|
||||
typ = typ.find()
|
||||
|
@ -283,7 +283,6 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
|
||||
# expr
|
||||
visit_Call = visit_unsupported
|
||||
visit_Compare = visit_unsupported
|
||||
visit_Dict = visit_unsupported
|
||||
visit_DictComp = visit_unsupported
|
||||
visit_Ellipsis = visit_unsupported
|
||||
@ -393,11 +392,14 @@ class Inferencer(algorithm.Visitor):
|
||||
node.attr_loc, [node.value.loc])
|
||||
self.engine.process(diag)
|
||||
|
||||
def _unify_collection(self, element, collection):
|
||||
# TODO: support more than just lists
|
||||
self._unify(builtins.TList(element.type), collection.type,
|
||||
element.loc, collection.loc)
|
||||
|
||||
def visit_SubscriptT(self, node):
|
||||
self.generic_visit(node)
|
||||
# TODO: support more than just lists
|
||||
self._unify(builtins.TList(node.type), node.value.type,
|
||||
node.loc, node.value.loc)
|
||||
self._unify_collection(element=node, collection=node.value)
|
||||
|
||||
def visit_IfExpT(self, node):
|
||||
self.generic_visit(node)
|
||||
@ -455,43 +457,39 @@ class Inferencer(algorithm.Visitor):
|
||||
def _coerce_one(self, typ, coerced_node, other_node):
|
||||
if coerced_node.type.find() == typ.find():
|
||||
return coerced_node
|
||||
elif isinstance(coerced_node, asttyped.CoerceT):
|
||||
node.type, node.other_expr = typ, other_node
|
||||
else:
|
||||
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
|
||||
loc=coerced_node.loc)
|
||||
self.visit(node)
|
||||
return node
|
||||
self.visit(node)
|
||||
return node
|
||||
|
||||
def _coerce_numeric(self, left, right):
|
||||
# Implements the coercion protocol.
|
||||
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
|
||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
||||
typ = builtins.TFloat()
|
||||
elif builtins.is_int(left.type) or builtins.is_int(right.type):
|
||||
left_width, right_width = \
|
||||
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
|
||||
if left_width and right_width:
|
||||
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
|
||||
else:
|
||||
typ = builtins.TInt()
|
||||
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
||||
node_types = [node.type for node in nodes]
|
||||
if any(map(types.is_var, node_types)): # not enough info yet
|
||||
return
|
||||
else: # conflicting types
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
"expression of type {typea}", {"typea": printer.name(left.type)},
|
||||
left.loc)
|
||||
note2 = diagnostic.Diagnostic("note",
|
||||
"expression of type {typeb}", {"typeb": printer.name(right.type)},
|
||||
right.loc)
|
||||
elif not all(map(builtins.is_numeric, node_types)):
|
||||
err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"cannot coerce {typea} and {typeb} to a common numeric type",
|
||||
{"typea": printer.name(left.type), "typeb": printer.name(right.type)},
|
||||
left.loc, [right.loc],
|
||||
[note1, note2])
|
||||
"cannot coerce {type} to a numeric type",
|
||||
{"type": types.TypePrinter().name(err_node.type)},
|
||||
err_node.loc, [])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
elif any(map(builtins.is_float, node_types)):
|
||||
typ = builtins.TFloat()
|
||||
elif any(map(builtins.is_int, node_types)):
|
||||
widths = map(builtins.get_int_width, node_types)
|
||||
if all(widths):
|
||||
typ = builtins.TInt(types.TValue(max(widths)))
|
||||
else:
|
||||
typ = builtins.TInt()
|
||||
else:
|
||||
assert False
|
||||
|
||||
return typ, typ, typ
|
||||
return map_return(typ)
|
||||
|
||||
def _order_by_pred(self, pred, left, right):
|
||||
if pred(left.type):
|
||||
@ -503,7 +501,7 @@ class Inferencer(algorithm.Visitor):
|
||||
|
||||
def _coerce_binop(self, op, left, right):
|
||||
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
ast.LShift, ast.RShift)):
|
||||
ast.LShift, ast.RShift)):
|
||||
# bitwise operators require integers
|
||||
for operand in (left, right):
|
||||
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
||||
@ -515,7 +513,7 @@ class Inferencer(algorithm.Visitor):
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
return self._coerce_numeric(left, right)
|
||||
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||
elif isinstance(op, ast.Add):
|
||||
# add works on numbers and also collections
|
||||
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
||||
@ -554,7 +552,7 @@ class Inferencer(algorithm.Visitor):
|
||||
left.loc, right.loc)
|
||||
return left.type, left.type, right.type
|
||||
else:
|
||||
return self._coerce_numeric(left, right)
|
||||
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||
elif isinstance(op, ast.Mult):
|
||||
# mult works on numbers and also number & collection
|
||||
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
||||
@ -585,10 +583,10 @@ class Inferencer(algorithm.Visitor):
|
||||
|
||||
return list_.type, left.type, right.type
|
||||
else:
|
||||
return self._coerce_numeric(left, right)
|
||||
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||
# numeric operators work on any kind of number
|
||||
return self._coerce_numeric(left, right)
|
||||
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
||||
else: # MatMult
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||
@ -605,6 +603,42 @@ class Inferencer(algorithm.Visitor):
|
||||
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
||||
node.type.unify(return_type) # should never fail
|
||||
|
||||
def visit_CompareT(self, node):
|
||||
self.generic_visit(node)
|
||||
pairs = zip([node.left] + node.comparators, node.comparators)
|
||||
if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)):
|
||||
for left, right in pairs:
|
||||
self._unify(left.type, right.type,
|
||||
left.loc, right.loc)
|
||||
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
|
||||
for left, right in pairs:
|
||||
self._unify_collection(element=left, collection=right)
|
||||
else: # Eq, NotEq, Lt, LtE, Gt, GtE
|
||||
operands = [node.left] + node.comparators
|
||||
operand_types = [operand.type for operand in operands]
|
||||
if any(map(builtins.is_collection, operand_types)):
|
||||
for left, right in pairs:
|
||||
self._unify(left.type, right.type,
|
||||
left.loc, right.loc)
|
||||
else:
|
||||
typ = self._coerce_numeric(operands)
|
||||
if typ:
|
||||
try:
|
||||
other_node = next(filter(lambda operand: operand.type.find() == typ.find(),
|
||||
operands))
|
||||
except StopIteration:
|
||||
# can't find an argument with an exact type, meaning
|
||||
# the return value is more generic than any of the inputs, meaning
|
||||
# the type is known (typ is not None), but its width is not
|
||||
def wide_enough(opreand):
|
||||
return types.is_mono(opreand.type) and \
|
||||
opreand.type.find().name == typ.find().name
|
||||
other_node = next(filter(wide_enough, operands))
|
||||
print(typ, other_node)
|
||||
node.left, *node.comparators = \
|
||||
[self._coerce_one(typ, operand, other_node) for operand in operands]
|
||||
node.type.unify(builtins.TBool())
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if len(node.targets) > 1:
|
||||
|
@ -27,3 +27,15 @@
|
||||
|
||||
a = []; a += [1]
|
||||
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
|
||||
|
||||
[] is [1]
|
||||
# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool
|
||||
|
||||
1 in [1]
|
||||
# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool
|
||||
|
||||
[] < [1]
|
||||
# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool
|
||||
|
||||
1.0 < 1
|
||||
# CHECK-L: 1.0:float < 1:int(width='v):float:bool
|
||||
|
@ -25,13 +25,7 @@
|
||||
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
|
||||
[1] * []
|
||||
|
||||
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type
|
||||
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
|
||||
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
|
||||
[] - None
|
||||
|
||||
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
||||
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
||||
# CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type
|
||||
[] - 1.0
|
||||
|
||||
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
|
||||
|
Loading…
Reference in New Issue
Block a user