forked from M-Labs/artiq
1
0
Fork 0

Add support for function types and LambdaT.

Also fix scoping of Nonlocal.
This commit is contained in:
whitequark 2015-06-15 11:30:50 +03:00
parent dbfdbc3c22
commit 20e0e69358
6 changed files with 195 additions and 46 deletions

View File

@ -29,7 +29,7 @@ class argT(ast.arg, commontyped):
class ClassDefT(ast.ClassDef, scoped): class ClassDefT(ast.ClassDef, scoped):
pass pass
class FunctionDefT(ast.FunctionDef, scoped): class FunctionDefT(ast.FunctionDef, scoped):
pass _types = ("signature_type",)
class ModuleT(ast.Module, scoped): class ModuleT(ast.Module, scoped):
pass pass

View File

@ -58,7 +58,7 @@ class TVar(Type):
def __repr__(self): def __repr__(self):
if self.parent is self: if self.parent is self:
return "TVar(%d)" % id(self) return "<py2llvm.types.TVar %d>" % id(self)
else: else:
return repr(self.find()) return repr(self.find())
@ -88,7 +88,7 @@ class TMono(Type):
raise UnificationError(self, other) raise UnificationError(self, other)
def __repr__(self): def __repr__(self):
return "TMono(%s, %s)" % (repr(self.name), repr(self.params)) return "py2llvm.types.TMono(%s, %s)" % (repr(self.name), repr(self.params))
def __getitem__(self, param): def __getitem__(self, param):
return self.params[param] return self.params[param]
@ -102,7 +102,11 @@ class TMono(Type):
return not (self == other) return not (self == other)
class TTuple(Type): class TTuple(Type):
"""A tuple type.""" """
A tuple type.
:ivar elts: (list of :class:`Type`) elements
"""
attributes = {} attributes = {}
@ -122,7 +126,7 @@ class TTuple(Type):
raise UnificationError(self, other) raise UnificationError(self, other)
def __repr__(self): def __repr__(self):
return "TTuple(%s)" % (", ".join(map(repr, self.elts))) return "py2llvm.types.TTuple(%s)" % repr(self.elts)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TTuple) and \ return isinstance(other, TTuple) and \
@ -131,6 +135,51 @@ class TTuple(Type):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
class TFunction(Type):
"""
A function type.
:ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`)
mandatory arguments
:ivar optargs: (:class:`collections.OrderedDict` of string to :class:`Type`)
optional arguments
:ivar ret: (:class:`Type`)
return type
"""
attributes = {}
def __init__(self, args, optargs, ret):
self.args, self.optargs, self.ret = args, optargs, ret
def find(self):
return self
def unify(self, other):
if isinstance(other, TFunction) and \
self.args.keys() == other.args.keys() and \
self.optargs.keys() == other.optargs.keys():
for selfarg, otherarg in zip(self.args.values() + self.optargs.values(),
other.args.values() + other.optargs.values()):
selfarg.unify(otherarg)
self.ret.unify(other.ret)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def __repr__(self):
return "py2llvm.types.TFunction(%s, %s, %s)" % \
(repr(self.args), repr(self.optargs), repr(self.ret))
def __eq__(self, other):
return isinstance(other, TFunction) and \
self.args == other.args and \
self.optargs == other.optargs
def __ne__(self, other):
return not (self == other)
class TValue(Type): class TValue(Type):
""" """
A type-level value (such as the integer denoting width of A type-level value (such as the integer denoting width of
@ -150,7 +199,7 @@ class TValue(Type):
raise UnificationError(self, other) raise UnificationError(self, other)
def __repr__(self): def __repr__(self):
return "TValue(%s)" % repr(self.value) return "py2llvm.types.TValue(%s)" % repr(self.value)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TValue) and \ return isinstance(other, TValue) and \
@ -216,6 +265,11 @@ class TypePrinter(object):
return "(%s,)" % self.name(typ.elts[0]) return "(%s,)" % self.name(typ.elts[0])
else: else:
return "(%s)" % ", ".join(list(map(self.name, typ.elts))) return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
elif isinstance(typ, TFunction):
args = []
args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args]
args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs]
return "(%s)->%s" % (", ".join(args), self.name(typ.ret))
elif isinstance(typ, TValue): elif isinstance(typ, TValue):
return repr(typ.value) return repr(typ.value)
else: else:

View File

@ -1,4 +1,5 @@
from pythonparser import source, ast, algorithm, diagnostic, parse_buffer from pythonparser import source, ast, algorithm, diagnostic, parse_buffer
from collections import OrderedDict
from . import asttyped, types, builtins from . import asttyped, types, builtins
# This visitor will be called for every node with a scope, # This visitor will be called for every node with a scope,
@ -23,6 +24,9 @@ class LocalExtractor(algorithm.Visitor):
# parameters can't be declared as global or nonlocal # parameters can't be declared as global or nonlocal
self.params = set() self.params = set()
if len(self.env_stack) == 0:
self.env_stack.append(self.typing_env)
def visit_in_assign(self, node): def visit_in_assign(self, node):
try: try:
self.in_assign = True self.in_assign = True
@ -58,14 +62,19 @@ class LocalExtractor(algorithm.Visitor):
self.in_root = True self.in_root = True
self.generic_visit(node) self.generic_visit(node)
visit_ClassDef = visit_root # don't look at inner scopes visit_Module = visit_root # don't look at inner scopes
visit_FunctionDef = visit_root visit_ClassDef = visit_root
visit_Lambda = visit_root visit_Lambda = visit_root
visit_DictComp = visit_root visit_DictComp = visit_root
visit_ListComp = visit_root visit_ListComp = visit_root
visit_SetComp = visit_root visit_SetComp = visit_root
visit_GeneratorExp = visit_root visit_GeneratorExp = visit_root
def visit_FunctionDef(self, node):
if self.in_root:
self._assignable(node.name)
self.visit_root(node)
def _assignable(self, name): def _assignable(self, name):
if name not in self.typing_env and name not in self.nonlocal_: if name not in self.typing_env and name not in self.nonlocal_:
self.typing_env[name] = types.TVar() self.typing_env[name] = types.TVar()
@ -103,7 +112,10 @@ class LocalExtractor(algorithm.Visitor):
if self._check_not_in(name, self.nonlocal_, "nonlocal", "global", loc) or \ if self._check_not_in(name, self.nonlocal_, "nonlocal", "global", loc) or \
self._check_not_in(name, self.params, "a parameter", "global", loc): self._check_not_in(name, self.params, "a parameter", "global", loc):
continue continue
self.global_.add(name) self.global_.add(name)
self._assignable(name)
self.env_stack[0][name] = self.typing_env[name]
def visit_Nonlocal(self, node): def visit_Nonlocal(self, node):
for name, loc in zip(node.names, node.name_locs): for name, loc in zip(node.names, node.name_locs):
@ -111,8 +123,9 @@ class LocalExtractor(algorithm.Visitor):
self._check_not_in(name, self.params, "a parameter", "nonlocal", loc): self._check_not_in(name, self.params, "a parameter", "nonlocal", loc):
continue continue
# nonlocal does not search global scope
found = False found = False
for outer_env in reversed(self.env_stack): for outer_env in reversed(self.env_stack[1:]):
if name in outer_env: if name in outer_env:
found = True found = True
break break
@ -164,12 +177,7 @@ class ASTTypedRewriter(algorithm.Transformer):
node = asttyped.ModuleT( node = asttyped.ModuleT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
body=node.body, loc=node.loc) body=node.body, loc=node.loc)
return self.generic_visit(node)
try:
self.env_stack.append(node.typing_env)
return self.generic_visit(node)
finally:
self.env_stack.pop()
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
@ -177,7 +185,7 @@ class ASTTypedRewriter(algorithm.Transformer):
node = asttyped.FunctionDefT( node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
return_type=types.TVar(), signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns, name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,
@ -295,6 +303,22 @@ class ASTTypedRewriter(algorithm.Transformer):
finally: finally:
self.env_stack.pop() self.env_stack.pop()
def visit_Lambda(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
node = asttyped.LambdaT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
type=types.TVar(),
args=node.args, body=node.body,
lambda_loc=node.lambda_loc, colon_loc=node.colon_loc, loc=node.loc)
try:
self.env_stack.append(node.typing_env)
return self.generic_visit(node)
finally:
self.env_stack.pop()
def visit_Raise(self, node): def visit_Raise(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
if node.cause: if node.cause:
@ -318,7 +342,6 @@ class ASTTypedRewriter(algorithm.Transformer):
visit_DictComp = visit_unsupported visit_DictComp = visit_unsupported
visit_Ellipsis = visit_unsupported visit_Ellipsis = visit_unsupported
visit_GeneratorExp = visit_unsupported visit_GeneratorExp = visit_unsupported
visit_Lambda = visit_unsupported
visit_Set = visit_unsupported visit_Set = visit_unsupported
visit_SetComp = visit_unsupported visit_SetComp = visit_unsupported
visit_Str = visit_unsupported visit_Str = visit_unsupported
@ -363,12 +386,14 @@ class Inferencer(algorithm.Visitor):
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"expression of type {typea}", "expression of type {typea}",
{"typea": printer.name(typea)}, {"typea": printer.name(typea)},
loca), loca)
diagnostic.Diagnostic("note",
"expression of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
] ]
if locb:
notes.append(
diagnostic.Diagnostic("note",
"expression of type {typeb}",
{"typeb": printer.name(typeb)},
locb))
highlights = [locb] if locb else [] highlights = [locb] if locb else []
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
@ -412,7 +437,8 @@ class Inferencer(algorithm.Visitor):
if not types.is_var(object_type): if not types.is_var(object_type):
if node.attr in object_type.attributes: if node.attr in object_type.attributes:
# assumes no free type variables in .attributes # assumes no free type variables in .attributes
node.type.unify(object_type.attributes[node.attr]) # should never fail self._unify(node.type, object_type.attributes[node.attr],
node.loc, None)
else: else:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'", "type {type} does not have an attribute '{attr}'",
@ -433,7 +459,8 @@ class Inferencer(algorithm.Visitor):
self.generic_visit(node) self.generic_visit(node)
self._unify(node.body.type, node.orelse.type, self._unify(node.body.type, node.orelse.type,
node.body.loc, node.orelse.loc) node.body.loc, node.orelse.loc)
node.type.unify(node.body.type) # should never fail self._unify(node.type, node.body.type,
node.loc, None)
def visit_BoolOpT(self, node): def visit_BoolOpT(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -445,10 +472,12 @@ class Inferencer(algorithm.Visitor):
self.generic_visit(node) self.generic_visit(node)
operand_type = node.operand.type.find() operand_type = node.operand.type.find()
if isinstance(node.op, ast.Not): if isinstance(node.op, ast.Not):
node.type.unify(builtins.TBool()) # should never fail self._unify(node.type, builtins.TBool(),
node.loc, None)
elif isinstance(node.op, ast.Invert): elif isinstance(node.op, ast.Invert):
if builtins.is_int(operand_type): if builtins.is_int(operand_type):
node.type.unify(operand_type) # should never fail self._unify(node.type, operand_type,
node.loc, None)
elif not types.is_var(operand_type): elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"expected '~' operand to be of integer type, not {type}", "expected '~' operand to be of integer type, not {type}",
@ -457,7 +486,8 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
else: # UAdd, USub else: # UAdd, USub
if builtins.is_numeric(operand_type): if builtins.is_numeric(operand_type):
node.type.unify(operand_type) # should never fail self._unify(node.type, operand_type,
node.loc, None)
elif not types.is_var(operand_type): elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"expected unary '{op}' operand to be of numeric type, not {type}", "expected unary '{op}' operand to be of numeric type, not {type}",
@ -572,7 +602,6 @@ class Inferencer(algorithm.Visitor):
return return
if types.is_tuple(collection.type): if types.is_tuple(collection.type):
# should never fail
return types.TTuple(left.type.find().elts + return types.TTuple(left.type.find().elts +
right.type.find().elts), left.type, right.type right.type.find().elts), left.type, right.type
elif builtins.is_list(collection.type): elif builtins.is_list(collection.type):
@ -629,7 +658,8 @@ class Inferencer(algorithm.Visitor):
return_type, left_type, right_type = coerced return_type, left_type, right_type = coerced
node.left = self._coerce_one(left_type, node.left, other_node=node.right) node.left = self._coerce_one(left_type, node.left, other_node=node.right)
node.right = self._coerce_one(right_type, node.right, other_node=node.left) node.right = self._coerce_one(right_type, node.right, other_node=node.left)
node.type.unify(return_type) # should never fail self._unify(node.type, return_type,
node.loc, None)
def visit_CompareT(self, node): def visit_CompareT(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -665,16 +695,25 @@ class Inferencer(algorithm.Visitor):
print(typ, other_node) print(typ, other_node)
node.left, *node.comparators = \ node.left, *node.comparators = \
[self._coerce_one(typ, operand, other_node) for operand in operands] [self._coerce_one(typ, operand, other_node) for operand in operands]
node.type.unify(builtins.TBool()) self._unify(node.type, builtins.TBool(),
node.loc, None)
def visit_ListCompT(self, node): def visit_ListCompT(self, node):
self.generic_visit(node) self.generic_visit(node)
node.type.unify(builtins.TList(node.elt.type)) # should never fail self._unify(node.type, builtins.TList(node.elt.type),
node.loc, None)
def visit_comprehension(self, node): def visit_comprehension(self, node):
self.generic_visit(node) self.generic_visit(node)
self._unify_collection(element=node.target, collection=node.iter) self._unify_collection(element=node.target, collection=node.iter)
def visit_LambdaT(self, node):
self.generic_visit(node)
signature_type = self._type_from_arguments(node.args, node.body.type)
if signature_type:
self._unify(node.type, signature_type,
node.loc, None)
def visit_Assign(self, node): def visit_Assign(self, node):
self.generic_visit(node) self.generic_visit(node)
if len(node.targets) > 1: if len(node.targets) > 1:
@ -762,6 +801,32 @@ class Inferencer(algorithm.Visitor):
node.context_expr.loc) node.context_expr.loc)
self.engine.process(diag) self.engine.process(diag)
def _type_from_arguments(self, node, ret):
self.generic_visit(node)
for (sigil_loc, vararg) in ((node.star_loc, node.vararg),
(node.dstar_loc, node.kwarg)):
if vararg:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported", {},
sigil_loc, [vararg.loc])
self.engine.process(diag)
return
def extract_args(arg_nodes):
args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes]
return OrderedDict(args)
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
extract_args(node.args[len(node.defaults):]),
ret)
def visit_arguments(self, node):
self.generic_visit(node)
for arg, default in zip(node.args[len(node.defaults):], node.defaults):
self._unify(arg.type, default.type,
arg.loc, default.loc)
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
old_function, self.function = self.function, node old_function, self.function = self.function, node
old_in_loop, self.in_loop = self.in_loop, False old_in_loop, self.in_loop = self.in_loop, False
@ -769,6 +834,18 @@ class Inferencer(algorithm.Visitor):
self.function = old_function self.function = old_function
self.in_loop = old_in_loop self.in_loop = old_in_loop
if any(node.decorator_list):
diag = diagnostic.Diagnostic("error",
"decorators are not supported", {},
node.at_locs[0], [node.decorator_list[0].loc])
self.engine.process(diag)
return
signature_type = self._type_from_arguments(node.args, node.return_type)
if signature_type:
self._unify(node.signature_type, signature_type,
node.name_loc, None)
def visit_Return(self, node): def visit_Return(self, node):
if not self.function: if not self.function:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",

View File

@ -1,29 +1,35 @@
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t # RUN: %python -m artiq.py2llvm.typing +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t # RUN: OutputCheck %s --file-to-check=%t
x = 1
def a(): def a():
# CHECK-L: ${LINE:+1}: error: cannot declare name 'x' as nonlocal: it is not bound in any outer scope # CHECK-L: ${LINE:+1}: error: cannot declare name 'x' as nonlocal: it is not bound in any outer scope
nonlocal x nonlocal x
x = 1 def f():
def b(): y = 1
nonlocal x def b():
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be nonlocal and global simultaneously nonlocal y
global x # CHECK-L: ${LINE:+1}: error: name 'y' cannot be nonlocal and global simultaneously
global y
def c(): def c():
global x global y
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be global and nonlocal simultaneously # CHECK-L: ${LINE:+1}: error: name 'y' cannot be global and nonlocal simultaneously
nonlocal x nonlocal y
def d(x): def d(y):
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and global simultaneously # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and global simultaneously
global x global y
def e(x): def e(y):
# CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and nonlocal simultaneously # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and nonlocal simultaneously
nonlocal x nonlocal y
# CHECK-L: ${LINE:+1}: error: duplicate parameter 'x' # CHECK-L: ${LINE:+1}: error: duplicate parameter 'x'
def f(x, x): def f(x, x):
pass pass
# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported
def g(*x):
pass

View File

@ -0,0 +1,8 @@
# RUN: %python -m artiq.py2llvm.typing %s >%t
# RUN: OutputCheck %s --file-to-check=%t
def f():
global x
x = 1
# CHECK-L: [x:int(width='a)]
[x]

View File

@ -50,3 +50,7 @@ not 1
[x for x in [1]] [x for x in [1]]
# CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j)) # CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j))
lambda x, y=1: x
# CHECK-L: lambda x:'a, y:int(width='b)=1:int(width='b): x:'a:(x:'a, ?y:int(width='b))->'a