forked from M-Labs/artiq
Add support for referring to host values in embedded functions.
This commit is contained in:
parent
353f454a29
commit
50448ef554
@ -5,7 +5,7 @@ the references to the host objects and translates the functions
|
||||
annotated as ``@kernel`` when they are referenced.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import inspect, os
|
||||
from pythonparser import ast, source, diagnostic, parse_buffer
|
||||
from . import types, builtins, asttyped, prelude
|
||||
from .transforms import ASTTypedRewriter, Inferencer
|
||||
@ -28,11 +28,12 @@ class ASTSynthesizer:
|
||||
|
||||
def quote(self, value):
|
||||
"""Construct an AST fragment equal to `value`."""
|
||||
if value in (None, True, False):
|
||||
if node.value is True or node.value is False:
|
||||
typ = builtins.TBool()
|
||||
elif node.value is None:
|
||||
typ = builtins.TNone()
|
||||
if value is None:
|
||||
typ = builtins.TNone()
|
||||
return asttyped.NameConstantT(value=value, type=typ,
|
||||
loc=self._add(repr(value)))
|
||||
elif value is True or value is False:
|
||||
typ = builtins.TBool()
|
||||
return asttyped.NameConstantT(value=value, type=typ,
|
||||
loc=self._add(repr(value)))
|
||||
elif isinstance(value, (int, float)):
|
||||
@ -45,12 +46,12 @@ class ASTSynthesizer:
|
||||
elif isinstance(value, list):
|
||||
begin_loc = self._add("[")
|
||||
elts = []
|
||||
for index, elt in value:
|
||||
for index, elt in enumerate(value):
|
||||
elts.append(self.quote(elt))
|
||||
if index < len(value) - 1:
|
||||
self._add(", ")
|
||||
end_loc = self._add("]")
|
||||
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
|
||||
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
||||
begin_loc=begin_loc, end_loc=end_loc,
|
||||
loc=begin_loc.join(end_loc))
|
||||
else:
|
||||
@ -99,7 +100,43 @@ class ASTSynthesizer:
|
||||
loc=name_loc.join(end_loc))
|
||||
|
||||
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
pass
|
||||
def __init__(self, engine, prelude, globals, host_environment, quote_function):
|
||||
super().__init__(engine, prelude)
|
||||
self.globals = globals
|
||||
self.env_stack.append(self.globals)
|
||||
|
||||
self.host_environment = host_environment
|
||||
self.quote_function = quote_function
|
||||
|
||||
def visit_Name(self, node):
|
||||
typ = super()._try_find_name(node.id)
|
||||
if typ is not None:
|
||||
# Value from device environment.
|
||||
return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
|
||||
loc=node.loc)
|
||||
else:
|
||||
# Try to find this value in the host environment and quote it.
|
||||
if node.id in self.host_environment:
|
||||
value = self.host_environment[node.id]
|
||||
if inspect.isfunction(value):
|
||||
# It's a function. We need to translate the function and insert
|
||||
# a reference to it.
|
||||
function_name = self.quote_function(value)
|
||||
return asttyped.NameT(id=function_name, ctx=None,
|
||||
type=self.globals[function_name],
|
||||
loc=node.loc)
|
||||
|
||||
else:
|
||||
# It's just a value. Quote it.
|
||||
synthesizer = ASTSynthesizer()
|
||||
node = synthesizer.quote(value)
|
||||
synthesizer.finalize()
|
||||
return node
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"name '{name}' is not bound to anything", {"name":node.id},
|
||||
node.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
class Stitcher:
|
||||
def __init__(self, engine=None):
|
||||
@ -108,24 +145,30 @@ class Stitcher:
|
||||
else:
|
||||
self.engine = engine
|
||||
|
||||
self.asttyped_rewriter = StitchingASTTypedRewriter(
|
||||
engine=self.engine, globals=prelude.globals())
|
||||
self.inferencer = Inferencer(engine=self.engine)
|
||||
self.name = ""
|
||||
self.typedtree = []
|
||||
self.prelude = prelude.globals()
|
||||
self.globals = {}
|
||||
|
||||
self.name = "stitched"
|
||||
self.typedtree = None
|
||||
self.globals = self.asttyped_rewriter.globals
|
||||
self.functions = {}
|
||||
|
||||
self.rpc_map = {}
|
||||
|
||||
def _iterate(self):
|
||||
inferencer = Inferencer(engine=self.engine)
|
||||
|
||||
# Iterate inference to fixed point.
|
||||
self.inference_finished = False
|
||||
while not self.inference_finished:
|
||||
self.inference_finished = True
|
||||
self.inferencer.visit(self.typedtree)
|
||||
inferencer.visit(self.typedtree)
|
||||
|
||||
def _parse_embedded_function(self, function):
|
||||
# After we have found all functions, synthesize a module to hold them.
|
||||
self.typedtree = asttyped.ModuleT(
|
||||
typing_env=self.globals, globals_in_scope=set(),
|
||||
body=self.typedtree, loc=None)
|
||||
|
||||
def _quote_embedded_function(self, function):
|
||||
if not hasattr(function, "artiq_embedded"):
|
||||
raise ValueError("{} is not an embedded function".format(repr(function)))
|
||||
|
||||
@ -133,25 +176,62 @@ class Stitcher:
|
||||
embedded_function = function.artiq_embedded.function
|
||||
source_code = inspect.getsource(embedded_function)
|
||||
filename = embedded_function.__code__.co_filename
|
||||
module_name, _ = os.path.splitext(os.path.basename(filename))
|
||||
first_line = embedded_function.__code__.co_firstlineno
|
||||
|
||||
# Extract function environment.
|
||||
host_environment = dict()
|
||||
host_environment.update(embedded_function.__globals__)
|
||||
cells = embedded_function.__closure__
|
||||
cell_names = embedded_function.__code__.co_freevars
|
||||
host_environment.update({var: cells[index] for index, var in enumerate(cell_names)})
|
||||
|
||||
# Parse.
|
||||
source_buffer = source.Buffer(source_code, filename, first_line)
|
||||
parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
|
||||
function_node = parsetree.body[0]
|
||||
|
||||
# Mangle the name, since we put everything into a single module.
|
||||
function_node.name = "{}.{}".format(module_name, function_node.name)
|
||||
|
||||
# Normally, LocalExtractor would populate the typing environment
|
||||
# of the module with the function name. However, since we run
|
||||
# ASTTypedRewriter on the function node directly, we need to do it
|
||||
# explicitly.
|
||||
self.globals[function_node.name] = types.TVar()
|
||||
|
||||
# Memoize the function before typing it to handle recursive
|
||||
# invocations.
|
||||
self.functions[function] = function_node.name
|
||||
|
||||
# Rewrite into typed form.
|
||||
typedtree = self.asttyped_rewriter.visit(parsetree)
|
||||
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||
engine=self.engine, prelude=self.prelude,
|
||||
globals=self.globals, host_environment=host_environment,
|
||||
quote_function=self._quote_function)
|
||||
return asttyped_rewriter.visit(function_node)
|
||||
|
||||
return typedtree, typedtree.body[0]
|
||||
def _quote_function(self, function):
|
||||
if function in self.functions:
|
||||
return self.functions[function]
|
||||
|
||||
# Insert the typed AST for the new function and restart inference.
|
||||
# It doesn't really matter where we insert as long as it is before
|
||||
# the final call.
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.insert(0, function_node)
|
||||
self.inference_finished = False
|
||||
return function_node.name
|
||||
|
||||
def stitch_call(self, function, args, kwargs):
|
||||
self.typedtree, function_node = self._parse_embedded_function(function)
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.append(function_node)
|
||||
|
||||
# We synthesize fake source code for the initial call so that
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = ASTSynthesizer()
|
||||
call_node = synthesizer.call(function_node, args, kwargs)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.body.append(call_node)
|
||||
self.typedtree.append(call_node)
|
||||
|
||||
self._iterate()
|
||||
|
@ -67,7 +67,10 @@ class Module:
|
||||
|
||||
def entry_point(self):
|
||||
"""Return the name of the function that is the entry point of this module."""
|
||||
return self.name + ".__modinit__"
|
||||
if self.name != "":
|
||||
return self.name + ".__modinit__"
|
||||
else:
|
||||
return "__modinit__"
|
||||
|
||||
def __repr__(self):
|
||||
printer = types.TypePrinter()
|
||||
|
@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
def __init__(self, module_name, engine):
|
||||
self.engine = engine
|
||||
self.functions = []
|
||||
self.name = [module_name]
|
||||
self.name = [module_name] if module_name != "" else []
|
||||
self.current_loc = None
|
||||
self.current_function = None
|
||||
self.current_globals = set()
|
||||
|
@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
via :class:`LocalExtractor`.
|
||||
"""
|
||||
|
||||
def __init__(self, engine, globals):
|
||||
def __init__(self, engine, prelude):
|
||||
self.engine = engine
|
||||
self.globals = None
|
||||
self.env_stack = [globals]
|
||||
self.env_stack = [prelude]
|
||||
|
||||
def _try_find_name(self, name):
|
||||
for typing_env in reversed(self.env_stack):
|
||||
|
Loading…
Reference in New Issue
Block a user