From 50448ef554a53b14ba2d5e1260d25b5780459598 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 7 Aug 2015 13:20:29 +0300 Subject: [PATCH] Add support for referring to host values in embedded functions. --- artiq/compiler/embedding.py | 124 ++++++++++++++---- artiq/compiler/module.py | 5 +- .../compiler/transforms/artiq_ir_generator.py | 2 +- .../compiler/transforms/asttyped_rewriter.py | 4 +- 4 files changed, 109 insertions(+), 26 deletions(-) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index a719c3a31..74498d436 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -5,7 +5,7 @@ the references to the host objects and translates the functions annotated as ``@kernel`` when they are referenced. """ -import inspect +import inspect, os from pythonparser import ast, source, diagnostic, parse_buffer from . import types, builtins, asttyped, prelude from .transforms import ASTTypedRewriter, Inferencer @@ -28,11 +28,12 @@ class ASTSynthesizer: def quote(self, value): """Construct an AST fragment equal to `value`.""" - if value in (None, True, False): - if node.value is True or node.value is False: - typ = builtins.TBool() - elif node.value is None: - typ = builtins.TNone() + if value is None: + typ = builtins.TNone() + return asttyped.NameConstantT(value=value, type=typ, + loc=self._add(repr(value))) + elif value is True or value is False: + typ = builtins.TBool() return asttyped.NameConstantT(value=value, type=typ, loc=self._add(repr(value))) elif isinstance(value, (int, float)): @@ -45,12 +46,12 @@ class ASTSynthesizer: elif isinstance(value, list): begin_loc = self._add("[") elts = [] - for index, elt in value: + for index, elt in enumerate(value): elts.append(self.quote(elt)) if index < len(value) - 1: self._add(", ") end_loc = self._add("]") - return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(), + return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), begin_loc=begin_loc, end_loc=end_loc, loc=begin_loc.join(end_loc)) else: @@ -99,7 +100,43 @@ class ASTSynthesizer: loc=name_loc.join(end_loc)) class StitchingASTTypedRewriter(ASTTypedRewriter): - pass + def __init__(self, engine, prelude, globals, host_environment, quote_function): + super().__init__(engine, prelude) + self.globals = globals + self.env_stack.append(self.globals) + + self.host_environment = host_environment + self.quote_function = quote_function + + def visit_Name(self, node): + typ = super()._try_find_name(node.id) + if typ is not None: + # Value from device environment. + return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx, + loc=node.loc) + else: + # Try to find this value in the host environment and quote it. + if node.id in self.host_environment: + value = self.host_environment[node.id] + if inspect.isfunction(value): + # It's a function. We need to translate the function and insert + # a reference to it. + function_name = self.quote_function(value) + return asttyped.NameT(id=function_name, ctx=None, + type=self.globals[function_name], + loc=node.loc) + + else: + # It's just a value. Quote it. + synthesizer = ASTSynthesizer() + node = synthesizer.quote(value) + synthesizer.finalize() + return node + else: + diag = diagnostic.Diagnostic("fatal", + "name '{name}' is not bound to anything", {"name":node.id}, + node.loc) + self.engine.process(diag) class Stitcher: def __init__(self, engine=None): @@ -108,24 +145,30 @@ class Stitcher: else: self.engine = engine - self.asttyped_rewriter = StitchingASTTypedRewriter( - engine=self.engine, globals=prelude.globals()) - self.inferencer = Inferencer(engine=self.engine) + self.name = "" + self.typedtree = [] + self.prelude = prelude.globals() + self.globals = {} - self.name = "stitched" - self.typedtree = None - self.globals = self.asttyped_rewriter.globals + self.functions = {} self.rpc_map = {} def _iterate(self): + inferencer = Inferencer(engine=self.engine) + # Iterate inference to fixed point. self.inference_finished = False while not self.inference_finished: self.inference_finished = True - self.inferencer.visit(self.typedtree) + inferencer.visit(self.typedtree) - def _parse_embedded_function(self, function): + # After we have found all functions, synthesize a module to hold them. + self.typedtree = asttyped.ModuleT( + typing_env=self.globals, globals_in_scope=set(), + body=self.typedtree, loc=None) + + def _quote_embedded_function(self, function): if not hasattr(function, "artiq_embedded"): raise ValueError("{} is not an embedded function".format(repr(function))) @@ -133,25 +176,62 @@ class Stitcher: embedded_function = function.artiq_embedded.function source_code = inspect.getsource(embedded_function) filename = embedded_function.__code__.co_filename + module_name, _ = os.path.splitext(os.path.basename(filename)) first_line = embedded_function.__code__.co_firstlineno + # Extract function environment. + host_environment = dict() + host_environment.update(embedded_function.__globals__) + cells = embedded_function.__closure__ + cell_names = embedded_function.__code__.co_freevars + host_environment.update({var: cells[index] for index, var in enumerate(cell_names)}) + # Parse. source_buffer = source.Buffer(source_code, filename, first_line) parsetree, comments = parse_buffer(source_buffer, engine=self.engine) + function_node = parsetree.body[0] + + # Mangle the name, since we put everything into a single module. + function_node.name = "{}.{}".format(module_name, function_node.name) + + # Normally, LocalExtractor would populate the typing environment + # of the module with the function name. However, since we run + # ASTTypedRewriter on the function node directly, we need to do it + # explicitly. + self.globals[function_node.name] = types.TVar() + + # Memoize the function before typing it to handle recursive + # invocations. + self.functions[function] = function_node.name # Rewrite into typed form. - typedtree = self.asttyped_rewriter.visit(parsetree) + asttyped_rewriter = StitchingASTTypedRewriter( + engine=self.engine, prelude=self.prelude, + globals=self.globals, host_environment=host_environment, + quote_function=self._quote_function) + return asttyped_rewriter.visit(function_node) - return typedtree, typedtree.body[0] + def _quote_function(self, function): + if function in self.functions: + return self.functions[function] + + # Insert the typed AST for the new function and restart inference. + # It doesn't really matter where we insert as long as it is before + # the final call. + function_node = self._quote_embedded_function(function) + self.typedtree.insert(0, function_node) + self.inference_finished = False + return function_node.name def stitch_call(self, function, args, kwargs): - self.typedtree, function_node = self._parse_embedded_function(function) + function_node = self._quote_embedded_function(function) + self.typedtree.append(function_node) - # We synthesize fake source code for the initial call so that + # We synthesize source code for the initial call so that # diagnostics would have something meaningful to display to the user. synthesizer = ASTSynthesizer() call_node = synthesizer.call(function_node, args, kwargs) synthesizer.finalize() - self.typedtree.body.append(call_node) + self.typedtree.append(call_node) self._iterate() diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index 2f62da1d8..d35a887c4 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -67,7 +67,10 @@ class Module: def entry_point(self): """Return the name of the function that is the entry point of this module.""" - return self.name + ".__modinit__" + if self.name != "": + return self.name + ".__modinit__" + else: + return "__modinit__" def __repr__(self): printer = types.TypePrinter() diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 845ef3189..84bcbc6c5 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor): def __init__(self, module_name, engine): self.engine = engine self.functions = [] - self.name = [module_name] + self.name = [module_name] if module_name != "" else [] self.current_loc = None self.current_function = None self.current_globals = set() diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index 85b96e694..51b84efb0 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer): via :class:`LocalExtractor`. """ - def __init__(self, engine, globals): + def __init__(self, engine, prelude): self.engine = engine self.globals = None - self.env_stack = [globals] + self.env_stack = [prelude] def _try_find_name(self, name): for typing_env in reversed(self.env_stack):