Add support for referring to host values in embedded functions.

This commit is contained in:
whitequark 2015-08-07 13:20:29 +03:00
parent 353f454a29
commit 50448ef554
4 changed files with 109 additions and 26 deletions

View File

@ -5,7 +5,7 @@ the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced. annotated as ``@kernel`` when they are referenced.
""" """
import inspect import inspect, os
from pythonparser import ast, source, diagnostic, parse_buffer from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer from .transforms import ASTTypedRewriter, Inferencer
@ -28,13 +28,14 @@ class ASTSynthesizer:
def quote(self, value): def quote(self, value):
"""Construct an AST fragment equal to `value`.""" """Construct an AST fragment equal to `value`."""
if value in (None, True, False): if value is None:
if node.value is True or node.value is False:
typ = builtins.TBool()
elif node.value is None:
typ = builtins.TNone() typ = builtins.TNone()
return asttyped.NameConstantT(value=value, type=typ, return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value))) loc=self._add(repr(value)))
elif value is True or value is False:
typ = builtins.TBool()
return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):
if isinstance(value, int): if isinstance(value, int):
typ = builtins.TInt() typ = builtins.TInt()
@ -45,12 +46,12 @@ class ASTSynthesizer:
elif isinstance(value, list): elif isinstance(value, list):
begin_loc = self._add("[") begin_loc = self._add("[")
elts = [] elts = []
for index, elt in value: for index, elt in enumerate(value):
elts.append(self.quote(elt)) elts.append(self.quote(elt))
if index < len(value) - 1: if index < len(value) - 1:
self._add(", ") self._add(", ")
end_loc = self._add("]") end_loc = self._add("]")
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(), return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc, begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc)) loc=begin_loc.join(end_loc))
else: else:
@ -99,7 +100,43 @@ class ASTSynthesizer:
loc=name_loc.join(end_loc)) loc=name_loc.join(end_loc))
class StitchingASTTypedRewriter(ASTTypedRewriter): class StitchingASTTypedRewriter(ASTTypedRewriter):
pass def __init__(self, engine, prelude, globals, host_environment, quote_function):
super().__init__(engine, prelude)
self.globals = globals
self.env_stack.append(self.globals)
self.host_environment = host_environment
self.quote_function = quote_function
def visit_Name(self, node):
typ = super()._try_find_name(node.id)
if typ is not None:
# Value from device environment.
return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
loc=node.loc)
else:
# Try to find this value in the host environment and quote it.
if node.id in self.host_environment:
value = self.host_environment[node.id]
if inspect.isfunction(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self.quote_function(value)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=node.loc)
else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer()
node = synthesizer.quote(value)
synthesizer.finalize()
return node
else:
diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything", {"name":node.id},
node.loc)
self.engine.process(diag)
class Stitcher: class Stitcher:
def __init__(self, engine=None): def __init__(self, engine=None):
@ -108,24 +145,30 @@ class Stitcher:
else: else:
self.engine = engine self.engine = engine
self.asttyped_rewriter = StitchingASTTypedRewriter( self.name = ""
engine=self.engine, globals=prelude.globals()) self.typedtree = []
self.inferencer = Inferencer(engine=self.engine) self.prelude = prelude.globals()
self.globals = {}
self.name = "stitched" self.functions = {}
self.typedtree = None
self.globals = self.asttyped_rewriter.globals
self.rpc_map = {} self.rpc_map = {}
def _iterate(self): def _iterate(self):
inferencer = Inferencer(engine=self.engine)
# Iterate inference to fixed point. # Iterate inference to fixed point.
self.inference_finished = False self.inference_finished = False
while not self.inference_finished: while not self.inference_finished:
self.inference_finished = True self.inference_finished = True
self.inferencer.visit(self.typedtree) inferencer.visit(self.typedtree)
def _parse_embedded_function(self, function): # After we have found all functions, synthesize a module to hold them.
self.typedtree = asttyped.ModuleT(
typing_env=self.globals, globals_in_scope=set(),
body=self.typedtree, loc=None)
def _quote_embedded_function(self, function):
if not hasattr(function, "artiq_embedded"): if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function))) raise ValueError("{} is not an embedded function".format(repr(function)))
@ -133,25 +176,62 @@ class Stitcher:
embedded_function = function.artiq_embedded.function embedded_function = function.artiq_embedded.function
source_code = inspect.getsource(embedded_function) source_code = inspect.getsource(embedded_function)
filename = embedded_function.__code__.co_filename filename = embedded_function.__code__.co_filename
module_name, _ = os.path.splitext(os.path.basename(filename))
first_line = embedded_function.__code__.co_firstlineno first_line = embedded_function.__code__.co_firstlineno
# Extract function environment.
host_environment = dict()
host_environment.update(embedded_function.__globals__)
cells = embedded_function.__closure__
cell_names = embedded_function.__code__.co_freevars
host_environment.update({var: cells[index] for index, var in enumerate(cell_names)})
# Parse. # Parse.
source_buffer = source.Buffer(source_code, filename, first_line) source_buffer = source.Buffer(source_code, filename, first_line)
parsetree, comments = parse_buffer(source_buffer, engine=self.engine) parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
function_node = parsetree.body[0]
# Mangle the name, since we put everything into a single module.
function_node.name = "{}.{}".format(module_name, function_node.name)
# Normally, LocalExtractor would populate the typing environment
# of the module with the function name. However, since we run
# ASTTypedRewriter on the function node directly, we need to do it
# explicitly.
self.globals[function_node.name] = types.TVar()
# Memoize the function before typing it to handle recursive
# invocations.
self.functions[function] = function_node.name
# Rewrite into typed form. # Rewrite into typed form.
typedtree = self.asttyped_rewriter.visit(parsetree) asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment,
quote_function=self._quote_function)
return asttyped_rewriter.visit(function_node)
return typedtree, typedtree.body[0] def _quote_function(self, function):
if function in self.functions:
return self.functions[function]
# Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node)
self.inference_finished = False
return function_node.name
def stitch_call(self, function, args, kwargs): def stitch_call(self, function, args, kwargs):
self.typedtree, function_node = self._parse_embedded_function(function) function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)
# We synthesize fake source code for the initial call so that # We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user. # diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer() synthesizer = ASTSynthesizer()
call_node = synthesizer.call(function_node, args, kwargs) call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize() synthesizer.finalize()
self.typedtree.body.append(call_node) self.typedtree.append(call_node)
self._iterate() self._iterate()

View File

@ -67,7 +67,10 @@ class Module:
def entry_point(self): def entry_point(self):
"""Return the name of the function that is the entry point of this module.""" """Return the name of the function that is the entry point of this module."""
if self.name != "":
return self.name + ".__modinit__" return self.name + ".__modinit__"
else:
return "__modinit__"
def __repr__(self): def __repr__(self):
printer = types.TypePrinter() printer = types.TypePrinter()

View File

@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def __init__(self, module_name, engine): def __init__(self, module_name, engine):
self.engine = engine self.engine = engine
self.functions = [] self.functions = []
self.name = [module_name] self.name = [module_name] if module_name != "" else []
self.current_loc = None self.current_loc = None
self.current_function = None self.current_function = None
self.current_globals = set() self.current_globals = set()

View File

@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
via :class:`LocalExtractor`. via :class:`LocalExtractor`.
""" """
def __init__(self, engine, globals): def __init__(self, engine, prelude):
self.engine = engine self.engine = engine
self.globals = None self.globals = None
self.env_stack = [globals] self.env_stack = [prelude]
def _try_find_name(self, name): def _try_find_name(self, name):
for typing_env in reversed(self.env_stack): for typing_env in reversed(self.env_stack):