From 3a1fc729cfd5559399cde44679933ea42edc081e Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 13 Apr 2017 08:01:46 +0000 Subject: [PATCH] compiler: refactor type annotations recognizing in kernels. The new implementation is much more generic, more robust, and shares code with the same for syscalls as well as RPCs. Fixes #713. --- artiq/compiler/embedding.py | 233 ++++++++++++------------- artiq/test/lit/embedding/annotation.py | 17 ++ 2 files changed, 131 insertions(+), 119 deletions(-) create mode 100644 artiq/test/lit/embedding/annotation.py diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 121158c57..2f7edc96c 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -423,49 +423,9 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): self.host_environment = host_environment self.quote = quote - def match_annotation(self, annot): - if isinstance(annot, ast.Name): - if annot.id == "TNone": - return builtins.TNone() - if annot.id == "TBool": - return builtins.TBool() - if annot.id == "TInt32": - return builtins.TInt(types.TValue(32)) - if annot.id == "TInt64": - return builtins.TInt(types.TValue(64)) - if annot.id == "TFloat": - return builtins.TFloat() - if annot.id == "TStr": - return builtins.TStr() - if annot.id == "TRange32": - return builtins.TRange(builtins.TInt(types.TValue(32))) - if annot.id == "TRange64": - return builtins.TRange(builtins.TInt(types.TValue(64))) - if annot.id == "TVar": - return types.TVar() - elif (isinstance(annot, ast.Call) and - annot.keywords == [] and - annot.starargs is None and - annot.kwargs is None and - isinstance(annot.func, ast.Name)): - if annot.func.id == "TList" and len(annot.args) == 1: - elttyp = self.match_annotation(annot.args[0]) - if elttyp is not None: - return builtins.TList(elttyp) - else: - return None - - if annot is not None: - diag = diagnostic.Diagnostic("error", - "unrecognized type annotation", {}, - annot.loc) - self.engine.process(diag) - def visit_arg(self, node): typ = self._find_name(node.arg, node.loc) - annottyp = self.match_annotation(node.annotation) - if annottyp is not None: - typ.unify(annottyp) + # ignore annotations; these are handled in _quote_function return asttyped.argT(type=typ, arg=node.arg, annotation=None, arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) @@ -795,6 +755,79 @@ class Stitcher: value_map=self.value_map, quote_function=self._quote_function) + def _function_loc(self, function): + filename = function.__code__.co_filename + line = function.__code__.co_firstlineno + name = function.__code__.co_name + + source_line = linecache.getline(filename, line).lstrip() + while source_line.startswith("@") or source_line == "": + line += 1 + source_line = linecache.getline(filename, line).lstrip() + + if "" in function.__qualname__: + column = 0 # can't get column of lambda + else: + column = re.search("def", source_line).start(0) + source_buffer = source.Buffer(source_line, filename, line) + return source.Range(source_buffer, column, column) + + def _call_site_note(self, call_loc, fn_kind): + if call_loc: + if fn_kind == 'syscall': + return [diagnostic.Diagnostic("note", + "in system call here", {}, + call_loc)] + elif fn_kind == 'rpc': + return [diagnostic.Diagnostic("note", + "in function called remotely here", {}, + call_loc)] + elif fn_kind == 'kernel': + return [diagnostic.Diagnostic("note", + "in kernel function here", {}, + call_loc)] + else: + assert False + else: + return [] + + def _type_of_param(self, function, loc, param, fn_kind): + if param.annotation is not inspect.Parameter.empty: + # Type specified explicitly. + return self._extract_annot(function, param.annotation, + "argument '{}'".format(param.name), loc, + fn_kind) + elif fn_kind == 'syscall': + # Syscalls must be entirely annotated. + diag = diagnostic.Diagnostic("error", + "system call argument '{argument}' must have a type annotation", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, fn_kind)) + self.engine.process(diag) + elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty: + notes = [] + notes.append(diagnostic.Diagnostic("note", + "expanded from here while trying to infer a type for an" + " unannotated optional argument '{argument}' from its default value", + {"argument": param.name}, + self._function_loc(function))) + if loc is not None: + notes.append(self._call_site_note(loc, fn_kind)) + + with self.engine.context(*notes): + # Try and infer the type from the default value. + # This is tricky, because the default value might not have + # a well-defined type in APython. + # In this case, we bail out, but mention why we do it. + ast = self._quote(param.default, None) + Inferencer(engine=self.engine).visit(ast) + IntMonomorphizer(engine=self.engine).visit(ast) + return ast.type + else: + # Let the rest of the program decide. + return types.TVar() + def _quote_embedded_function(self, function, flags): if isinstance(function, SpecializedFunction): host_function = function.host_function @@ -811,6 +844,34 @@ class Stitcher: module_name = embedded_function.__globals__['__name__'] first_line = embedded_function.__code__.co_firstlineno + # Extract function annotation. + signature = inspect.signature(embedded_function) + loc = self._function_loc(embedded_function) + + arg_types = OrderedDict() + optarg_types = OrderedDict() + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_POSITIONAL or \ + param.kind == inspect.Parameter.VAR_KEYWORD: + diag = diagnostic.Diagnostic("error", + "variadic arguments are not supported; '{argument}' is variadic", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, fn_kind='kernel')) + self.engine.process(diag) + + arg_type = self._type_of_param(function, loc, param, fn_kind='kernel') + if param.default is inspect.Parameter.empty: + arg_types[param.name] = arg_type + else: + optarg_types[param.name] = arg_type + + if signature.return_annotation is not inspect.Signature.empty: + ret_type = self._extract_annot(function, signature.return_annotation, + "return type", loc, fn_kind='kernel') + else: + ret_type = types.TVar() + # Extract function environment. host_environment = dict() host_environment.update(embedded_function.__globals__) @@ -846,9 +907,9 @@ class Stitcher: # can handle quoting it. self.embedding_map.store_function(function, function_node.name) - # Memoize the function type before typing it to handle recursive + # Fill in the function type before typing it to handle recursive # invocations. - self.functions[function] = types.TVar() + self.functions[function] = types.TFunction(arg_types, optarg_types, ret_type) # Rewrite into typed form. asttyped_rewriter = StitchingASTTypedRewriter( @@ -866,86 +927,19 @@ class Stitcher: return function_node - def _function_loc(self, function): - filename = function.__code__.co_filename - line = function.__code__.co_firstlineno - name = function.__code__.co_name - - source_line = linecache.getline(filename, line).lstrip() - while source_line.startswith("@") or source_line == "": - line += 1 - source_line = linecache.getline(filename, line).lstrip() - - if "" in function.__qualname__: - column = 0 # can't get column of lambda - else: - column = re.search("def", source_line).start(0) - source_buffer = source.Buffer(source_line, filename, line) - return source.Range(source_buffer, column, column) - - def _call_site_note(self, call_loc, is_syscall): - if call_loc: - if is_syscall: - return [diagnostic.Diagnostic("note", - "in system call here", {}, - call_loc)] - else: - return [diagnostic.Diagnostic("note", - "in function called remotely here", {}, - call_loc)] - else: - return [] - - def _extract_annot(self, function, annot, kind, call_loc, is_syscall): + def _extract_annot(self, function, annot, kind, call_loc, fn_kind): if not isinstance(annot, types.Type): diag = diagnostic.Diagnostic("error", "type annotation for {kind}, '{annot}', is not an ARTIQ type", {"kind": kind, "annot": repr(annot)}, self._function_loc(function), - notes=self._call_site_note(call_loc, is_syscall)) + notes=self._call_site_note(call_loc, fn_kind)) self.engine.process(diag) return types.TVar() else: return annot - def _type_of_param(self, function, loc, param, is_syscall): - if param.annotation is not inspect.Parameter.empty: - # Type specified explicitly. - return self._extract_annot(function, param.annotation, - "argument '{}'".format(param.name), loc, - is_syscall) - elif is_syscall: - # Syscalls must be entirely annotated. - diag = diagnostic.Diagnostic("error", - "system call argument '{argument}' must have a type annotation", - {"argument": param.name}, - self._function_loc(function), - notes=self._call_site_note(loc, is_syscall)) - self.engine.process(diag) - elif param.default is not inspect.Parameter.empty: - notes = [] - notes.append(diagnostic.Diagnostic("note", - "expanded from here while trying to infer a type for an" - " unannotated optional argument '{argument}' from its default value", - {"argument": param.name}, - self._function_loc(function))) - if loc is not None: - notes.append(self._call_site_note(loc, is_syscall)) - - with self.engine.context(*notes): - # Try and infer the type from the default value. - # This is tricky, because the default value might not have - # a well-defined type in APython. - # In this case, we bail out, but mention why we do it. - ast = self._quote(param.default, None) - Inferencer(engine=self.engine).visit(ast) - IntMonomorphizer(engine=self.engine).visit(ast) - return ast.type - else: - # Let the rest of the program decide. - return types.TVar() - def _quote_syscall(self, function, loc): signature = inspect.signature(function) @@ -957,27 +951,28 @@ class Stitcher: "system calls must only use positional arguments; '{argument}' isn't", {"argument": param.name}, self._function_loc(function), - notes=self._call_site_note(loc, is_syscall=True)) + notes=self._call_site_note(loc, fn_kind='syscall')) self.engine.process(diag) if param.default is inspect.Parameter.empty: - arg_types[param.name] = self._type_of_param(function, loc, param, is_syscall=True) + arg_types[param.name] = self._type_of_param(function, loc, param, + fn_kind='syscall') else: diag = diagnostic.Diagnostic("error", "system call argument '{argument}' must not have a default value", {"argument": param.name}, self._function_loc(function), - notes=self._call_site_note(loc, is_syscall=True)) + notes=self._call_site_note(loc, fn_kind='syscall')) self.engine.process(diag) if signature.return_annotation is not inspect.Signature.empty: ret_type = self._extract_annot(function, signature.return_annotation, - "return type", loc, is_syscall=True) + "return type", loc, fn_kind='syscall') else: diag = diagnostic.Diagnostic("error", "system call must have a return type annotation", {}, self._function_loc(function), - notes=self._call_site_note(loc, is_syscall=True)) + notes=self._call_site_note(loc, fn_kind='syscall')) self.engine.process(diag) ret_type = types.TVar() @@ -1005,7 +1000,7 @@ class Stitcher: signature = inspect.signature(host_function.__func__) if signature.return_annotation is not inspect.Signature.empty: ret_type = self._extract_annot(host_function, signature.return_annotation, - "return type", loc, is_syscall=False) + "return type", loc, fn_kind='rpc') else: assert False @@ -1077,7 +1072,7 @@ class Stitcher: diag = diagnostic.Diagnostic("fatal", "this function cannot be called as an RPC", {}, self._function_loc(host_function), - notes=self._call_site_note(loc, is_syscall=False)) + notes=self._call_site_note(loc, fn_kind='rpc')) self.engine.process(diag) else: assert False diff --git a/artiq/test/lit/embedding/annotation.py b/artiq/test/lit/embedding/annotation.py new file mode 100644 index 000000000..7e9c05d9f --- /dev/null +++ b/artiq/test/lit/embedding/annotation.py @@ -0,0 +1,17 @@ +# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s +# RUN: OutputCheck %s --file-to-check=%t.ll + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK: i64 @_Z13testbench.foozz\(i64 %ARG.x, \{ i1, i64 \} %ARG.y\) + +@kernel +def foo(x: TInt64, y: TInt64 = 1) -> TInt64: + print(x+y) + return x+y + +@kernel +def entrypoint(): + print(foo(0, 2)) + print(foo(1, 3))