From 20e0e69358481c11e92ed9ae7ca3c49c5c85b6a5 Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 15 Jun 2015 11:30:50 +0300 Subject: [PATCH] Add support for function types and LambdaT. Also fix scoping of Nonlocal. --- artiq/py2llvm/asttyped.py | 2 +- artiq/py2llvm/types.py | 64 +++++++++++- artiq/py2llvm/typing.py | 127 +++++++++++++++++++----- lit-test/py2llvm/typing/error_locals.py | 36 ++++--- lit-test/py2llvm/typing/scoping.py | 8 ++ lit-test/py2llvm/typing/unify.py | 4 + 6 files changed, 195 insertions(+), 46 deletions(-) create mode 100644 lit-test/py2llvm/typing/scoping.py diff --git a/artiq/py2llvm/asttyped.py b/artiq/py2llvm/asttyped.py index e96edd641..a9762ac4e 100644 --- a/artiq/py2llvm/asttyped.py +++ b/artiq/py2llvm/asttyped.py @@ -29,7 +29,7 @@ class argT(ast.arg, commontyped): class ClassDefT(ast.ClassDef, scoped): pass class FunctionDefT(ast.FunctionDef, scoped): - pass + _types = ("signature_type",) class ModuleT(ast.Module, scoped): pass diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index 19ad56d48..84823ddaa 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -58,7 +58,7 @@ class TVar(Type): def __repr__(self): if self.parent is self: - return "TVar(%d)" % id(self) + return "" % id(self) else: return repr(self.find()) @@ -88,7 +88,7 @@ class TMono(Type): raise UnificationError(self, other) def __repr__(self): - return "TMono(%s, %s)" % (repr(self.name), repr(self.params)) + return "py2llvm.types.TMono(%s, %s)" % (repr(self.name), repr(self.params)) def __getitem__(self, param): return self.params[param] @@ -102,7 +102,11 @@ class TMono(Type): return not (self == other) class TTuple(Type): - """A tuple type.""" + """ + A tuple type. + + :ivar elts: (list of :class:`Type`) elements + """ attributes = {} @@ -122,7 +126,7 @@ class TTuple(Type): raise UnificationError(self, other) def __repr__(self): - return "TTuple(%s)" % (", ".join(map(repr, self.elts))) + return "py2llvm.types.TTuple(%s)" % repr(self.elts) def __eq__(self, other): return isinstance(other, TTuple) and \ @@ -131,6 +135,51 @@ class TTuple(Type): def __ne__(self, other): return not (self == other) +class TFunction(Type): + """ + A function type. + + :ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`) + mandatory arguments + :ivar optargs: (:class:`collections.OrderedDict` of string to :class:`Type`) + optional arguments + :ivar ret: (:class:`Type`) + return type + """ + + attributes = {} + + def __init__(self, args, optargs, ret): + self.args, self.optargs, self.ret = args, optargs, ret + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TFunction) and \ + self.args.keys() == other.args.keys() and \ + self.optargs.keys() == other.optargs.keys(): + for selfarg, otherarg in zip(self.args.values() + self.optargs.values(), + other.args.values() + other.optargs.values()): + selfarg.unify(otherarg) + self.ret.unify(other.ret) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def __repr__(self): + return "py2llvm.types.TFunction(%s, %s, %s)" % \ + (repr(self.args), repr(self.optargs), repr(self.ret)) + + def __eq__(self, other): + return isinstance(other, TFunction) and \ + self.args == other.args and \ + self.optargs == other.optargs + + def __ne__(self, other): + return not (self == other) + class TValue(Type): """ A type-level value (such as the integer denoting width of @@ -150,7 +199,7 @@ class TValue(Type): raise UnificationError(self, other) def __repr__(self): - return "TValue(%s)" % repr(self.value) + return "py2llvm.types.TValue(%s)" % repr(self.value) def __eq__(self, other): return isinstance(other, TValue) and \ @@ -216,6 +265,11 @@ class TypePrinter(object): return "(%s,)" % self.name(typ.elts[0]) else: return "(%s)" % ", ".join(list(map(self.name, typ.elts))) + elif isinstance(typ, TFunction): + args = [] + args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args] + args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs] + return "(%s)->%s" % (", ".join(args), self.name(typ.ret)) elif isinstance(typ, TValue): return repr(typ.value) else: diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 29981389a..6b10cc795 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -1,4 +1,5 @@ from pythonparser import source, ast, algorithm, diagnostic, parse_buffer +from collections import OrderedDict from . import asttyped, types, builtins # This visitor will be called for every node with a scope, @@ -23,6 +24,9 @@ class LocalExtractor(algorithm.Visitor): # parameters can't be declared as global or nonlocal self.params = set() + if len(self.env_stack) == 0: + self.env_stack.append(self.typing_env) + def visit_in_assign(self, node): try: self.in_assign = True @@ -58,14 +62,19 @@ class LocalExtractor(algorithm.Visitor): self.in_root = True self.generic_visit(node) - visit_ClassDef = visit_root # don't look at inner scopes - visit_FunctionDef = visit_root + visit_Module = visit_root # don't look at inner scopes + visit_ClassDef = visit_root visit_Lambda = visit_root visit_DictComp = visit_root visit_ListComp = visit_root visit_SetComp = visit_root visit_GeneratorExp = visit_root + def visit_FunctionDef(self, node): + if self.in_root: + self._assignable(node.name) + self.visit_root(node) + def _assignable(self, name): if name not in self.typing_env and name not in self.nonlocal_: self.typing_env[name] = types.TVar() @@ -103,7 +112,10 @@ class LocalExtractor(algorithm.Visitor): if self._check_not_in(name, self.nonlocal_, "nonlocal", "global", loc) or \ self._check_not_in(name, self.params, "a parameter", "global", loc): continue + self.global_.add(name) + self._assignable(name) + self.env_stack[0][name] = self.typing_env[name] def visit_Nonlocal(self, node): for name, loc in zip(node.names, node.name_locs): @@ -111,8 +123,9 @@ class LocalExtractor(algorithm.Visitor): self._check_not_in(name, self.params, "a parameter", "nonlocal", loc): continue + # nonlocal does not search global scope found = False - for outer_env in reversed(self.env_stack): + for outer_env in reversed(self.env_stack[1:]): if name in outer_env: found = True break @@ -164,12 +177,7 @@ class ASTTypedRewriter(algorithm.Transformer): node = asttyped.ModuleT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, body=node.body, loc=node.loc) - - try: - self.env_stack.append(node.typing_env) - return self.generic_visit(node) - finally: - self.env_stack.pop() + return self.generic_visit(node) def visit_FunctionDef(self, node): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) @@ -177,7 +185,7 @@ class ASTTypedRewriter(algorithm.Transformer): node = asttyped.FunctionDefT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, - return_type=types.TVar(), + signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(), name=node.name, args=node.args, returns=node.returns, body=node.body, decorator_list=node.decorator_list, keyword_loc=node.keyword_loc, name_loc=node.name_loc, @@ -295,6 +303,22 @@ class ASTTypedRewriter(algorithm.Transformer): finally: self.env_stack.pop() + def visit_Lambda(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.LambdaT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + type=types.TVar(), + args=node.args, body=node.body, + lambda_loc=node.lambda_loc, colon_loc=node.colon_loc, loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + def visit_Raise(self, node): node = self.generic_visit(node) if node.cause: @@ -318,7 +342,6 @@ class ASTTypedRewriter(algorithm.Transformer): visit_DictComp = visit_unsupported visit_Ellipsis = visit_unsupported visit_GeneratorExp = visit_unsupported - visit_Lambda = visit_unsupported visit_Set = visit_unsupported visit_SetComp = visit_unsupported visit_Str = visit_unsupported @@ -363,12 +386,14 @@ class Inferencer(algorithm.Visitor): diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, - loca), - diagnostic.Diagnostic("note", - "expression of type {typeb}", - {"typeb": printer.name(typeb)}, - locb) + loca) ] + if locb: + notes.append( + diagnostic.Diagnostic("note", + "expression of type {typeb}", + {"typeb": printer.name(typeb)}, + locb)) highlights = [locb] if locb else [] if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): @@ -412,7 +437,8 @@ class Inferencer(algorithm.Visitor): if not types.is_var(object_type): if node.attr in object_type.attributes: # assumes no free type variables in .attributes - node.type.unify(object_type.attributes[node.attr]) # should never fail + self._unify(node.type, object_type.attributes[node.attr], + node.loc, None) else: diag = diagnostic.Diagnostic("error", "type {type} does not have an attribute '{attr}'", @@ -433,7 +459,8 @@ class Inferencer(algorithm.Visitor): self.generic_visit(node) self._unify(node.body.type, node.orelse.type, node.body.loc, node.orelse.loc) - node.type.unify(node.body.type) # should never fail + self._unify(node.type, node.body.type, + node.loc, None) def visit_BoolOpT(self, node): self.generic_visit(node) @@ -445,10 +472,12 @@ class Inferencer(algorithm.Visitor): self.generic_visit(node) operand_type = node.operand.type.find() if isinstance(node.op, ast.Not): - node.type.unify(builtins.TBool()) # should never fail + self._unify(node.type, builtins.TBool(), + node.loc, None) elif isinstance(node.op, ast.Invert): if builtins.is_int(operand_type): - node.type.unify(operand_type) # should never fail + self._unify(node.type, operand_type, + node.loc, None) elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", "expected '~' operand to be of integer type, not {type}", @@ -457,7 +486,8 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) else: # UAdd, USub if builtins.is_numeric(operand_type): - node.type.unify(operand_type) # should never fail + self._unify(node.type, operand_type, + node.loc, None) elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", "expected unary '{op}' operand to be of numeric type, not {type}", @@ -572,7 +602,6 @@ class Inferencer(algorithm.Visitor): return if types.is_tuple(collection.type): - # should never fail return types.TTuple(left.type.find().elts + right.type.find().elts), left.type, right.type elif builtins.is_list(collection.type): @@ -629,7 +658,8 @@ class Inferencer(algorithm.Visitor): return_type, left_type, right_type = coerced node.left = self._coerce_one(left_type, node.left, other_node=node.right) node.right = self._coerce_one(right_type, node.right, other_node=node.left) - node.type.unify(return_type) # should never fail + self._unify(node.type, return_type, + node.loc, None) def visit_CompareT(self, node): self.generic_visit(node) @@ -665,16 +695,25 @@ class Inferencer(algorithm.Visitor): print(typ, other_node) node.left, *node.comparators = \ [self._coerce_one(typ, operand, other_node) for operand in operands] - node.type.unify(builtins.TBool()) + self._unify(node.type, builtins.TBool(), + node.loc, None) def visit_ListCompT(self, node): self.generic_visit(node) - node.type.unify(builtins.TList(node.elt.type)) # should never fail + self._unify(node.type, builtins.TList(node.elt.type), + node.loc, None) def visit_comprehension(self, node): self.generic_visit(node) self._unify_collection(element=node.target, collection=node.iter) + def visit_LambdaT(self, node): + self.generic_visit(node) + signature_type = self._type_from_arguments(node.args, node.body.type) + if signature_type: + self._unify(node.type, signature_type, + node.loc, None) + def visit_Assign(self, node): self.generic_visit(node) if len(node.targets) > 1: @@ -762,6 +801,32 @@ class Inferencer(algorithm.Visitor): node.context_expr.loc) self.engine.process(diag) + def _type_from_arguments(self, node, ret): + self.generic_visit(node) + + for (sigil_loc, vararg) in ((node.star_loc, node.vararg), + (node.dstar_loc, node.kwarg)): + if vararg: + diag = diagnostic.Diagnostic("error", + "variadic arguments are not supported", {}, + sigil_loc, [vararg.loc]) + self.engine.process(diag) + return + + def extract_args(arg_nodes): + args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes] + return OrderedDict(args) + + return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]), + extract_args(node.args[len(node.defaults):]), + ret) + + def visit_arguments(self, node): + self.generic_visit(node) + for arg, default in zip(node.args[len(node.defaults):], node.defaults): + self._unify(arg.type, default.type, + arg.loc, default.loc) + def visit_FunctionDefT(self, node): old_function, self.function = self.function, node old_in_loop, self.in_loop = self.in_loop, False @@ -769,6 +834,18 @@ class Inferencer(algorithm.Visitor): self.function = old_function self.in_loop = old_in_loop + if any(node.decorator_list): + diag = diagnostic.Diagnostic("error", + "decorators are not supported", {}, + node.at_locs[0], [node.decorator_list[0].loc]) + self.engine.process(diag) + return + + signature_type = self._type_from_arguments(node.args, node.return_type) + if signature_type: + self._unify(node.signature_type, signature_type, + node.name_loc, None) + def visit_Return(self, node): if not self.function: diag = diagnostic.Diagnostic("error", diff --git a/lit-test/py2llvm/typing/error_locals.py b/lit-test/py2llvm/typing/error_locals.py index 7d07d699b..836029b55 100644 --- a/lit-test/py2llvm/typing/error_locals.py +++ b/lit-test/py2llvm/typing/error_locals.py @@ -1,29 +1,35 @@ # RUN: %python -m artiq.py2llvm.typing +diag %s >%t # RUN: OutputCheck %s --file-to-check=%t +x = 1 def a(): # CHECK-L: ${LINE:+1}: error: cannot declare name 'x' as nonlocal: it is not bound in any outer scope nonlocal x -x = 1 -def b(): - nonlocal x - # CHECK-L: ${LINE:+1}: error: name 'x' cannot be nonlocal and global simultaneously - global x +def f(): + y = 1 + def b(): + nonlocal y + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be nonlocal and global simultaneously + global y -def c(): - global x - # CHECK-L: ${LINE:+1}: error: name 'x' cannot be global and nonlocal simultaneously - nonlocal x + def c(): + global y + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be global and nonlocal simultaneously + nonlocal y -def d(x): - # CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and global simultaneously - global x + def d(y): + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and global simultaneously + global y -def e(x): - # CHECK-L: ${LINE:+1}: error: name 'x' cannot be a parameter and nonlocal simultaneously - nonlocal x + def e(y): + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and nonlocal simultaneously + nonlocal y # CHECK-L: ${LINE:+1}: error: duplicate parameter 'x' def f(x, x): pass + +# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported +def g(*x): + pass diff --git a/lit-test/py2llvm/typing/scoping.py b/lit-test/py2llvm/typing/scoping.py new file mode 100644 index 000000000..96eefeb14 --- /dev/null +++ b/lit-test/py2llvm/typing/scoping.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.py2llvm.typing %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + global x + x = 1 +# CHECK-L: [x:int(width='a)] +[x] diff --git a/lit-test/py2llvm/typing/unify.py b/lit-test/py2llvm/typing/unify.py index 94207cc1f..ba441527f 100644 --- a/lit-test/py2llvm/typing/unify.py +++ b/lit-test/py2llvm/typing/unify.py @@ -50,3 +50,7 @@ not 1 [x for x in [1]] # CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j)) + +lambda x, y=1: x +# CHECK-L: lambda x:'a, y:int(width='b)=1:int(width='b): x:'a:(x:'a, ?y:int(width='b))->'a +