forked from M-Labs/artiq
Add range types.
This commit is contained in:
parent
71256a7109
commit
e07057c224
|
@ -31,6 +31,12 @@ class TList(types.TMono):
|
||||||
elt = types.TVar()
|
elt = types.TVar()
|
||||||
super().__init__("list", {"elt": elt})
|
super().__init__("list", {"elt": elt})
|
||||||
|
|
||||||
|
class TRange(types.TMono):
|
||||||
|
def __init__(self, elt=None):
|
||||||
|
if elt is None:
|
||||||
|
elt = types.TVar()
|
||||||
|
super().__init__("range", {"elt": elt})
|
||||||
|
|
||||||
def fn_bool():
|
def fn_bool():
|
||||||
return types.TBuiltin("class bool")
|
return types.TBuiltin("class bool")
|
||||||
|
|
||||||
|
@ -43,15 +49,15 @@ def fn_float():
|
||||||
def fn_list():
|
def fn_list():
|
||||||
return types.TBuiltin("class list")
|
return types.TBuiltin("class list")
|
||||||
|
|
||||||
|
def fn_range():
|
||||||
|
return types.TBuiltin("function range")
|
||||||
|
|
||||||
def fn_len():
|
def fn_len():
|
||||||
return types.TBuiltin("function len")
|
return types.TBuiltin("function len")
|
||||||
|
|
||||||
def fn_round():
|
def fn_round():
|
||||||
return types.TBuiltin("function round")
|
return types.TBuiltin("function round")
|
||||||
|
|
||||||
def fn_range():
|
|
||||||
return types.TBuiltin("function range")
|
|
||||||
|
|
||||||
def fn_syscall():
|
def fn_syscall():
|
||||||
return types.TBuiltin("function syscall")
|
return types.TBuiltin("function syscall")
|
||||||
|
|
||||||
|
@ -87,6 +93,21 @@ def is_list(typ, elt=None):
|
||||||
else:
|
else:
|
||||||
return types.is_mono(typ, "list")
|
return types.is_mono(typ, "list")
|
||||||
|
|
||||||
|
def is_range(typ, elt=None):
|
||||||
|
if elt:
|
||||||
|
return types.is_mono(typ, "range", {"elt": elt})
|
||||||
|
else:
|
||||||
|
return types.is_mono(typ, "range")
|
||||||
|
|
||||||
|
def is_iterable(typ):
|
||||||
|
typ = typ.find()
|
||||||
|
return isinstance(typ, types.TMono) and \
|
||||||
|
typ.name in ('list', 'range')
|
||||||
|
|
||||||
|
def get_iterable_elt(typ):
|
||||||
|
if is_iterable(typ):
|
||||||
|
return typ.find()["elt"]
|
||||||
|
|
||||||
def is_collection(typ):
|
def is_collection(typ):
|
||||||
typ = typ.find()
|
typ = typ.find()
|
||||||
return isinstance(typ, types.TTuple) or \
|
return isinstance(typ, types.TTuple) or \
|
||||||
|
|
|
@ -11,8 +11,8 @@ def globals():
|
||||||
"int": builtins.fn_int(),
|
"int": builtins.fn_int(),
|
||||||
"float": builtins.fn_float(),
|
"float": builtins.fn_float(),
|
||||||
"list": builtins.fn_list(),
|
"list": builtins.fn_list(),
|
||||||
|
"range": builtins.fn_range(),
|
||||||
"len": builtins.fn_len(),
|
"len": builtins.fn_len(),
|
||||||
"round": builtins.fn_round(),
|
"round": builtins.fn_round(),
|
||||||
"range": builtins.fn_range(),
|
|
||||||
"syscall": builtins.fn_syscall(),
|
"syscall": builtins.fn_syscall(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -454,14 +454,22 @@ class Inferencer(algorithm.Visitor):
|
||||||
node.attr_loc, [node.value.loc])
|
node.attr_loc, [node.value.loc])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
def _unify_collection(self, element, collection):
|
def _unify_iterable(self, element, collection):
|
||||||
# TODO: support more than just lists
|
if builtins.is_iterable(collection.type):
|
||||||
self._unify(builtins.TList(element.type), collection.type,
|
rhs_type = collection.type.find()
|
||||||
|
rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type})
|
||||||
|
self._unify(rhs_wrapped_lhs_type, rhs_type,
|
||||||
element.loc, collection.loc)
|
element.loc, collection.loc)
|
||||||
|
elif not types.is_var(collection.type):
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"type {type} is not iterable",
|
||||||
|
{"type": types.TypePrinter().name(collection.type)},
|
||||||
|
collection.loc, [])
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
def visit_SubscriptT(self, node):
|
def visit_SubscriptT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self._unify_collection(element=node, collection=node.value)
|
self._unify_iterable(element=node, collection=node.value)
|
||||||
|
|
||||||
def visit_IfExpT(self, node):
|
def visit_IfExpT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
@ -678,7 +686,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
left.loc, right.loc)
|
left.loc, right.loc)
|
||||||
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
|
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
|
||||||
for left, right in pairs:
|
for left, right in pairs:
|
||||||
self._unify_collection(element=left, collection=right)
|
self._unify_iterable(element=left, collection=right)
|
||||||
else: # Eq, NotEq, Lt, LtE, Gt, GtE
|
else: # Eq, NotEq, Lt, LtE, Gt, GtE
|
||||||
operands = [node.left] + node.comparators
|
operands = [node.left] + node.comparators
|
||||||
operand_types = [operand.type for operand in operands]
|
operand_types = [operand.type for operand in operands]
|
||||||
|
@ -713,7 +721,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
def visit_comprehension(self, node):
|
def visit_comprehension(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self._unify_collection(element=node.target, collection=node.iter)
|
self._unify_iterable(element=node.target, collection=node.iter)
|
||||||
|
|
||||||
def visit_builtin_call(self, node):
|
def visit_builtin_call(self, node):
|
||||||
typ = node.func.type.find()
|
typ = node.func.type.find()
|
||||||
|
@ -770,6 +778,8 @@ class Inferencer(algorithm.Visitor):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"the width argument of int() must be an integer literal", {},
|
"the width argument of int() must be an integer literal", {},
|
||||||
node.keywords[0].loc)
|
node.keywords[0].loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
|
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
|
||||||
node.loc, None)
|
node.loc, None)
|
||||||
|
@ -805,9 +815,36 @@ class Inferencer(algorithm.Visitor):
|
||||||
pass # []
|
pass # []
|
||||||
else:
|
else:
|
||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
|
elif builtins.is_builtin(typ, "function range"):
|
||||||
|
valid_forms = lambda: [
|
||||||
|
valid_form("range(max:'a) -> range(elt='a)"),
|
||||||
|
valid_form("range(min:'a, max:'a) -> range(elt='a)"),
|
||||||
|
valid_form("range(min:'a, max:'a, step:'a) -> range(elt='a)"),
|
||||||
|
]
|
||||||
|
|
||||||
|
range_tvar = types.TVar()
|
||||||
|
self._unify(node.type, builtins.TRange(range_tvar),
|
||||||
|
node.loc, None)
|
||||||
|
|
||||||
|
if len(node.args) in (1, 2, 3) and len(node.keywords) == 0:
|
||||||
|
for arg in node.args:
|
||||||
|
self._unify(arg.type, range_tvar,
|
||||||
|
arg.loc, None)
|
||||||
|
|
||||||
|
if not builtins.is_numeric(arg.type):
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"this expression has type {type}",
|
||||||
|
{"type": types.TypePrinter().name(arg.type)},
|
||||||
|
arg.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"an argument of range() must be of a numeric type", {},
|
||||||
|
node.func.loc, notes=[note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
else:
|
||||||
|
diagnose(valid_forms())
|
||||||
elif builtins.is_builtin(typ, "function len"):
|
elif builtins.is_builtin(typ, "function len"):
|
||||||
valid_forms = lambda: [
|
valid_forms = lambda: [
|
||||||
valid_form("len(x:list(elt='a)) -> int(width='b)"),
|
valid_form("len(x:'a) -> int(width='b) where 'a is iterable"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: should be ssize_t-sized
|
# TODO: should be ssize_t-sized
|
||||||
|
@ -817,8 +854,17 @@ class Inferencer(algorithm.Visitor):
|
||||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||||
arg, = node.args
|
arg, = node.args
|
||||||
|
|
||||||
self._unify(arg.type, builtins.TList(),
|
if builtins.is_list(arg.type) or builtins.is_range(arg.type):
|
||||||
arg.loc, None)
|
pass
|
||||||
|
else:
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"this expression has type {type}",
|
||||||
|
{"type": types.TypePrinter().name(arg.type)},
|
||||||
|
arg.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"the argument of len() must be of an iterable type", {},
|
||||||
|
node.func.loc, notes=[note])
|
||||||
|
self.engine.process(diag)
|
||||||
else:
|
else:
|
||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
elif builtins.is_builtin(typ, "function round"):
|
elif builtins.is_builtin(typ, "function round"):
|
||||||
|
@ -836,13 +882,6 @@ class Inferencer(algorithm.Visitor):
|
||||||
arg.loc, None)
|
arg.loc, None)
|
||||||
else:
|
else:
|
||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
# TODO: add when there are range types
|
|
||||||
# elif builtins.is_builtin(typ, "function range"):
|
|
||||||
# valid_forms = lambda: [
|
|
||||||
# valid_form("range(max:'a) -> range(elt='a)"),
|
|
||||||
# valid_form("range(min:'a, max:'a) -> range(elt='a)"),
|
|
||||||
# valid_form("range(min:'a, max:'a, step:'a) -> range(elt='a)"),
|
|
||||||
# ]
|
|
||||||
# TODO: add when it is clear what interface syscall() has
|
# TODO: add when it is clear what interface syscall() has
|
||||||
# elif builtins.is_builtin(typ, "function syscall"):
|
# elif builtins.is_builtin(typ, "function syscall"):
|
||||||
# valid_Forms = lambda: [
|
# valid_Forms = lambda: [
|
||||||
|
@ -862,7 +901,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
if types.is_var(node.func.type):
|
if types.is_var(node.func.type):
|
||||||
return # not enough info yet
|
return # not enough info yet
|
||||||
elif types.is_mono(node.func.type) or types.is_builtin(node.func.type):
|
elif types.is_builtin(node.func.type):
|
||||||
return self.visit_builtin_call(node)
|
return self.visit_builtin_call(node)
|
||||||
elif not types.is_function(node.func.type):
|
elif not types.is_function(node.func.type):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
@ -988,9 +1027,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
old_in_loop, self.in_loop = self.in_loop, True
|
old_in_loop, self.in_loop = self.in_loop, True
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self.in_loop = old_in_loop
|
self.in_loop = old_in_loop
|
||||||
# TODO: support more than just lists
|
self._unify_iterable(node.target, node.iter)
|
||||||
self._unify(builtins.TList(node.target.type), node.iter.type,
|
|
||||||
node.target.loc, node.iter.loc)
|
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
old_in_loop, self.in_loop = self.in_loop, True
|
old_in_loop, self.in_loop = self.in_loop, True
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
a = 1
|
||||||
|
# CHECK-L: ${LINE:+1}: error: the width argument of int() must be an integer literal
|
||||||
|
int(1.0, width=a)
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: the argument of len() must be of an iterable type
|
||||||
|
len(1)
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: an argument of range() must be of a numeric type
|
||||||
|
range([])
|
|
@ -0,0 +1,5 @@
|
||||||
|
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: type int(width='a) is not iterable
|
||||||
|
for x in 1: pass
|
Loading…
Reference in New Issue