forked from M-Labs/artiq
Add support for function types and LambdaT.
Also fix scoping of Nonlocal.
This commit is contained in:
parent
dbfdbc3c22
commit
20e0e69358
|
@ -29,7 +29,7 @@ class argT(ast.arg, commontyped):
|
|||
class ClassDefT(ast.ClassDef, scoped):
|
||||
pass
|
||||
class FunctionDefT(ast.FunctionDef, scoped):
|
||||
pass
|
||||
_types = ("signature_type",)
|
||||
class ModuleT(ast.Module, scoped):
|
||||
pass
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ class TVar(Type):
|
|||
|
||||
def __repr__(self):
|
||||
if self.parent is self:
|
||||
return "TVar(%d)" % id(self)
|
||||
return "<py2llvm.types.TVar %d>" % id(self)
|
||||
else:
|
||||
return repr(self.find())
|
||||
|
||||
|
@ -88,7 +88,7 @@ class TMono(Type):
|
|||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "TMono(%s, %s)" % (repr(self.name), repr(self.params))
|
||||
return "py2llvm.types.TMono(%s, %s)" % (repr(self.name), repr(self.params))
|
||||
|
||||
def __getitem__(self, param):
|
||||
return self.params[param]
|
||||
|
@ -102,7 +102,11 @@ class TMono(Type):
|
|||
return not (self == other)
|
||||
|
||||
class TTuple(Type):
|
||||
"""A tuple type."""
|
||||
"""
|
||||
A tuple type.
|
||||
|
||||
:ivar elts: (list of :class:`Type`) elements
|
||||
"""
|
||||
|
||||
attributes = {}
|
||||
|
||||
|
@ -122,7 +126,7 @@ class TTuple(Type):
|
|||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
|
||||
return "py2llvm.types.TTuple(%s)" % repr(self.elts)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TTuple) and \
|
||||
|
@ -131,6 +135,51 @@ class TTuple(Type):
|
|||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TFunction(Type):
|
||||
"""
|
||||
A function type.
|
||||
|
||||
:ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`)
|
||||
mandatory arguments
|
||||
:ivar optargs: (:class:`collections.OrderedDict` of string to :class:`Type`)
|
||||
optional arguments
|
||||
:ivar ret: (:class:`Type`)
|
||||
return type
|
||||
"""
|
||||
|
||||
attributes = {}
|
||||
|
||||
def __init__(self, args, optargs, ret):
|
||||
self.args, self.optargs, self.ret = args, optargs, ret
|
||||
|
||||
def find(self):
|
||||
return self
|
||||
|
||||
def unify(self, other):
|
||||
if isinstance(other, TFunction) and \
|
||||
self.args.keys() == other.args.keys() and \
|
||||
self.optargs.keys() == other.optargs.keys():
|
||||
for selfarg, otherarg in zip(self.args.values() + self.optargs.values(),
|
||||
other.args.values() + other.optargs.values()):
|
||||
selfarg.unify(otherarg)
|
||||
self.ret.unify(other.ret)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "py2llvm.types.TFunction(%s, %s, %s)" % \
|
||||
(repr(self.args), repr(self.optargs), repr(self.ret))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TFunction) and \
|
||||
self.args == other.args and \
|
||||
self.optargs == other.optargs
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TValue(Type):
|
||||
"""
|
||||
A type-level value (such as the integer denoting width of
|
||||
|
@ -150,7 +199,7 @@ class TValue(Type):
|
|||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "TValue(%s)" % repr(self.value)
|
||||
return "py2llvm.types.TValue(%s)" % repr(self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TValue) and \
|
||||
|
@ -216,6 +265,11 @@ class TypePrinter(object):
|
|||
return "(%s,)" % self.name(typ.elts[0])
|
||||
else:
|
||||
return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
|
||||
elif isinstance(typ, TFunction):
|
||||
args = []
|
||||
args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args]
|
||||
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
|
||||
return "(%s)->%s" % (", ".join(args), self.name(typ.ret))
|
||||
elif isinstance(typ, TValue):
|
||||
return repr(typ.value)
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from pythonparser import source, ast, algorithm, diagnostic, parse_buffer
|
||||
from collections import OrderedDict
|
||||
from . import asttyped, types, builtins
|
||||
|
||||
# This visitor will be called for every node with a scope,
|
||||
|
@ -23,6 +24,9 @@ class LocalExtractor(algorithm.Visitor):
|
|||
# parameters can't be declared as global or nonlocal
|
||||
self.params = set()
|
||||
|
||||
if len(self.env_stack) == 0:
|
||||
self.env_stack.append(self.typing_env)
|
||||
|
||||
def visit_in_assign(self, node):
|
||||
try:
|
||||
self.in_assign = True
|
||||
|
@ -58,14 +62,19 @@ class LocalExtractor(algorithm.Visitor):
|
|||
self.in_root = True
|
||||
self.generic_visit(node)
|
||||
|
||||
visit_ClassDef = visit_root # don't look at inner scopes
|
||||
visit_FunctionDef = visit_root
|
||||
visit_Module = visit_root # don't look at inner scopes
|
||||
visit_ClassDef = visit_root
|
||||
visit_Lambda = visit_root
|
||||
visit_DictComp = visit_root
|
||||
visit_ListComp = visit_root
|
||||
visit_SetComp = visit_root
|
||||
visit_GeneratorExp = visit_root
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
if self.in_root:
|
||||
self._assignable(node.name)
|
||||
self.visit_root(node)
|
||||
|
||||
def _assignable(self, name):
|
||||
if name not in self.typing_env and name not in self.nonlocal_:
|
||||
self.typing_env[name] = types.TVar()
|
||||
|
@ -103,7 +112,10 @@ class LocalExtractor(algorithm.Visitor):
|
|||
if self._check_not_in(name, self.nonlocal_, "nonlocal", "global", loc) or \
|
||||
self._check_not_in(name, self.params, "a parameter", "global", loc):
|
||||
continue
|
||||
|
||||
self.global_.add(name)
|
||||
self._assignable(name)
|
||||
self.env_stack[0][name] = self.typing_env[name]
|
||||
|
||||
def visit_Nonlocal(self, node):
|
||||
for name, loc in zip(node.names, node.name_locs):
|
||||
|
@ -111,8 +123,9 @@ class LocalExtractor(algorithm.Visitor):
|
|||
self._check_not_in(name, self.params, "a parameter", "nonlocal", loc):
|
||||
continue
|
||||
|
||||
# nonlocal does not search global scope
|
||||
found = False
|
||||
for outer_env in reversed(self.env_stack):
|
||||
for outer_env in reversed(self.env_stack[1:]):
|
||||
if name in outer_env:
|
||||
found = True
|
||||
break
|
||||
|
@ -164,12 +177,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
node = asttyped.ModuleT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
body=node.body, loc=node.loc)
|
||||
|
||||
try:
|
||||
self.env_stack.append(node.typing_env)
|
||||
return self.generic_visit(node)
|
||||
finally:
|
||||
self.env_stack.pop()
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
|
@ -177,7 +185,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
|
||||
node = asttyped.FunctionDefT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
return_type=types.TVar(),
|
||||
signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(),
|
||||
name=node.name, args=node.args, returns=node.returns,
|
||||
body=node.body, decorator_list=node.decorator_list,
|
||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||
|
@ -295,6 +303,22 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
finally:
|
||||
self.env_stack.pop()
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
||||
node = asttyped.LambdaT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
type=types.TVar(),
|
||||
args=node.args, body=node.body,
|
||||
lambda_loc=node.lambda_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||
|
||||
try:
|
||||
self.env_stack.append(node.typing_env)
|
||||
return self.generic_visit(node)
|
||||
finally:
|
||||
self.env_stack.pop()
|
||||
|
||||
def visit_Raise(self, node):
|
||||
node = self.generic_visit(node)
|
||||
if node.cause:
|
||||
|
@ -318,7 +342,6 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||
visit_DictComp = visit_unsupported
|
||||
visit_Ellipsis = visit_unsupported
|
||||
visit_GeneratorExp = visit_unsupported
|
||||
visit_Lambda = visit_unsupported
|
||||
visit_Set = visit_unsupported
|
||||
visit_SetComp = visit_unsupported
|
||||
visit_Str = visit_unsupported
|
||||
|
@ -363,12 +386,14 @@ class Inferencer(algorithm.Visitor):
|
|||
diagnostic.Diagnostic("note",
|
||||
"expression of type {typea}",
|
||||
{"typea": printer.name(typea)},
|
||||
loca),
|
||||
diagnostic.Diagnostic("note",
|
||||
"expression of type {typeb}",
|
||||
{"typeb": printer.name(typeb)},
|
||||
locb)
|
||||
loca)
|
||||
]
|
||||
if locb:
|
||||
notes.append(
|
||||
diagnostic.Diagnostic("note",
|
||||
"expression of type {typeb}",
|
||||
{"typeb": printer.name(typeb)},
|
||||
locb))
|
||||
|
||||
highlights = [locb] if locb else []
|
||||
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
|
||||
|
@ -412,7 +437,8 @@ class Inferencer(algorithm.Visitor):
|
|||
if not types.is_var(object_type):
|
||||
if node.attr in object_type.attributes:
|
||||
# assumes no free type variables in .attributes
|
||||
node.type.unify(object_type.attributes[node.attr]) # should never fail
|
||||
self._unify(node.type, object_type.attributes[node.attr],
|
||||
node.loc, None)
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} does not have an attribute '{attr}'",
|
||||
|
@ -433,7 +459,8 @@ class Inferencer(algorithm.Visitor):
|
|||
self.generic_visit(node)
|
||||
self._unify(node.body.type, node.orelse.type,
|
||||
node.body.loc, node.orelse.loc)
|
||||
node.type.unify(node.body.type) # should never fail
|
||||
self._unify(node.type, node.body.type,
|
||||
node.loc, None)
|
||||
|
||||
def visit_BoolOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
|
@ -445,10 +472,12 @@ class Inferencer(algorithm.Visitor):
|
|||
self.generic_visit(node)
|
||||
operand_type = node.operand.type.find()
|
||||
if isinstance(node.op, ast.Not):
|
||||
node.type.unify(builtins.TBool()) # should never fail
|
||||
self._unify(node.type, builtins.TBool(),
|
||||
node.loc, None)
|
||||
elif isinstance(node.op, ast.Invert):
|
||||
if builtins.is_int(operand_type):
|
||||
node.type.unify(operand_type) # should never fail
|
||||
self._unify(node.type, operand_type,
|
||||
node.loc, None)
|
||||
elif not types.is_var(operand_type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected '~' operand to be of integer type, not {type}",
|
||||
|
@ -457,7 +486,8 @@ class Inferencer(algorithm.Visitor):
|
|||
self.engine.process(diag)
|
||||
else: # UAdd, USub
|
||||
if builtins.is_numeric(operand_type):
|
||||
node.type.unify(operand_type) # should never fail
|
||||
self._unify(node.type, operand_type,
|
||||
node.loc, None)
|
||||
elif not types.is_var(operand_type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected unary '{op}' operand to be of numeric type, not {type}",
|
||||
|
@ -572,7 +602,6 @@ class Inferencer(algorithm.Visitor):
|
|||
return
|
||||
|
||||
if types.is_tuple(collection.type):
|
||||
# should never fail
|
||||
return types.TTuple(left.type.find().elts +
|
||||
right.type.find().elts), left.type, right.type
|
||||
elif builtins.is_list(collection.type):
|
||||
|
@ -629,7 +658,8 @@ class Inferencer(algorithm.Visitor):
|
|||
return_type, left_type, right_type = coerced
|
||||
node.left = self._coerce_one(left_type, node.left, other_node=node.right)
|
||||
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
||||
node.type.unify(return_type) # should never fail
|
||||
self._unify(node.type, return_type,
|
||||
node.loc, None)
|
||||
|
||||
def visit_CompareT(self, node):
|
||||
self.generic_visit(node)
|
||||
|
@ -665,16 +695,25 @@ class Inferencer(algorithm.Visitor):
|
|||
print(typ, other_node)
|
||||
node.left, *node.comparators = \
|
||||
[self._coerce_one(typ, operand, other_node) for operand in operands]
|
||||
node.type.unify(builtins.TBool())
|
||||
self._unify(node.type, builtins.TBool(),
|
||||
node.loc, None)
|
||||
|
||||
def visit_ListCompT(self, node):
|
||||
self.generic_visit(node)
|
||||
node.type.unify(builtins.TList(node.elt.type)) # should never fail
|
||||
self._unify(node.type, builtins.TList(node.elt.type),
|
||||
node.loc, None)
|
||||
|
||||
def visit_comprehension(self, node):
|
||||
self.generic_visit(node)
|
||||
self._unify_collection(element=node.target, collection=node.iter)
|
||||
|
||||
def visit_LambdaT(self, node):
|
||||
self.generic_visit(node)
|
||||
signature_type = self._type_from_arguments(node.args, node.body.type)
|
||||
if signature_type:
|
||||
self._unify(node.type, signature_type,
|
||||
node.loc, None)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if len(node.targets) > 1:
|
||||
|
@ -762,6 +801,32 @@ class Inferencer(algorithm.Visitor):
|
|||
node.context_expr.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def _type_from_arguments(self, node, ret):
|
||||
self.generic_visit(node)
|
||||
|
||||
for (sigil_loc, vararg) in ((node.star_loc, node.vararg),
|
||||
(node.dstar_loc, node.kwarg)):
|
||||
if vararg:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"variadic arguments are not supported", {},
|
||||
sigil_loc, [vararg.loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
def extract_args(arg_nodes):
|
||||
args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes]
|
||||
return OrderedDict(args)
|
||||
|
||||
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
|
||||
extract_args(node.args[len(node.defaults):]),
|
||||
ret)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
self.generic_visit(node)
|
||||
for arg, default in zip(node.args[len(node.defaults):], node.defaults):
|
||||
self._unify(arg.type, default.type,
|
||||
arg.loc, default.loc)
|
||||
|
||||
def visit_FunctionDefT(self, node):
|
||||
old_function, self.function = self.function, node
|
||||
old_in_loop, self.in_loop = self.in_loop, False
|
||||
|
@ -769,6 +834,18 @@ class Inferencer(algorithm.Visitor):
|
|||
self.function = old_function
|
||||
self.in_loop = old_in_loop
|
||||
|
||||
if any(node.decorator_list):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"decorators are not supported", {},
|
||||
node.at_locs[0], [node.decorator_list[0].loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
signature_type = self._type_from_arguments(node.args, node.return_type)
|
||||
if signature_type:
|
||||
self._unify(node.signature_type, signature_type,
|
||||
node.name_loc, None)
|
||||
|
||||
def visit_Return(self, node):
|
||||
if not self.function:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
|
|
|
@ -1,29 +1,35 @@
|
|||
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
x = 1
|
||||
def a():
|
||||
# CHECK-L: ${LINE:+1}: error: cannot declare name 'x' as nonlocal: it is not bound in any outer scope
|
||||
nonlocal x
|
||||
|
||||
x = 1
|
||||
def b():
|
||||
nonlocal x
|
||||
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be nonlocal and global simultaneously
|
||||
global x
|
||||
def f():
|
||||
y = 1
|
||||
def b():
|
||||
nonlocal y
|
||||
# CHECK-L: ${LINE:+1}: error: name 'y' cannot be nonlocal and global simultaneously
|
||||
global y
|
||||
|
||||
def c():
|
||||
global x
|
||||
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be global and nonlocal simultaneously
|
||||
nonlocal x
|
||||
def c():
|
||||
global y
|
||||
# CHECK-L: ${LINE:+1}: error: name 'y' cannot be global and nonlocal simultaneously
|
||||
nonlocal y
|
||||
|
||||
def d(x):
|
||||
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and global simultaneously
|
||||
global x
|
||||
def d(y):
|
||||
# CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and global simultaneously
|
||||
global y
|
||||
|
||||
def e(x):
|
||||
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and nonlocal simultaneously
|
||||
nonlocal x
|
||||
def e(y):
|
||||
# CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and nonlocal simultaneously
|
||||
nonlocal y
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: duplicate parameter 'x'
|
||||
def f(x, x):
|
||||
pass
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
|
||||
def g(*x):
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# RUN: %python -m artiq.py2llvm.typing %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
def f():
|
||||
global x
|
||||
x = 1
|
||||
# CHECK-L: [x:int(width='a)]
|
||||
[x]
|
|
@ -50,3 +50,7 @@ not 1
|
|||
|
||||
[x for x in [1]]
|
||||
# CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j))
|
||||
|
||||
lambda x, y=1: x
|
||||
# CHECK-L: lambda x:'a, y:int(width='b)=1:int(width='b): x:'a:(x:'a, ?y:int(width='b))->'a
|
||||
|
||||
|
|
Loading…
Reference in New Issue