forked from M-Labs/artiq
Add support for referring to host values in embedded functions.
This commit is contained in:
parent
353f454a29
commit
50448ef554
|
@ -5,7 +5,7 @@ the references to the host objects and translates the functions
|
||||||
annotated as ``@kernel`` when they are referenced.
|
annotated as ``@kernel`` when they are referenced.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect, os
|
||||||
from pythonparser import ast, source, diagnostic, parse_buffer
|
from pythonparser import ast, source, diagnostic, parse_buffer
|
||||||
from . import types, builtins, asttyped, prelude
|
from . import types, builtins, asttyped, prelude
|
||||||
from .transforms import ASTTypedRewriter, Inferencer
|
from .transforms import ASTTypedRewriter, Inferencer
|
||||||
|
@ -28,11 +28,12 @@ class ASTSynthesizer:
|
||||||
|
|
||||||
def quote(self, value):
|
def quote(self, value):
|
||||||
"""Construct an AST fragment equal to `value`."""
|
"""Construct an AST fragment equal to `value`."""
|
||||||
if value in (None, True, False):
|
if value is None:
|
||||||
if node.value is True or node.value is False:
|
typ = builtins.TNone()
|
||||||
typ = builtins.TBool()
|
return asttyped.NameConstantT(value=value, type=typ,
|
||||||
elif node.value is None:
|
loc=self._add(repr(value)))
|
||||||
typ = builtins.TNone()
|
elif value is True or value is False:
|
||||||
|
typ = builtins.TBool()
|
||||||
return asttyped.NameConstantT(value=value, type=typ,
|
return asttyped.NameConstantT(value=value, type=typ,
|
||||||
loc=self._add(repr(value)))
|
loc=self._add(repr(value)))
|
||||||
elif isinstance(value, (int, float)):
|
elif isinstance(value, (int, float)):
|
||||||
|
@ -45,12 +46,12 @@ class ASTSynthesizer:
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
begin_loc = self._add("[")
|
begin_loc = self._add("[")
|
||||||
elts = []
|
elts = []
|
||||||
for index, elt in value:
|
for index, elt in enumerate(value):
|
||||||
elts.append(self.quote(elt))
|
elts.append(self.quote(elt))
|
||||||
if index < len(value) - 1:
|
if index < len(value) - 1:
|
||||||
self._add(", ")
|
self._add(", ")
|
||||||
end_loc = self._add("]")
|
end_loc = self._add("]")
|
||||||
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
|
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
||||||
begin_loc=begin_loc, end_loc=end_loc,
|
begin_loc=begin_loc, end_loc=end_loc,
|
||||||
loc=begin_loc.join(end_loc))
|
loc=begin_loc.join(end_loc))
|
||||||
else:
|
else:
|
||||||
|
@ -99,7 +100,43 @@ class ASTSynthesizer:
|
||||||
loc=name_loc.join(end_loc))
|
loc=name_loc.join(end_loc))
|
||||||
|
|
||||||
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||||
pass
|
def __init__(self, engine, prelude, globals, host_environment, quote_function):
|
||||||
|
super().__init__(engine, prelude)
|
||||||
|
self.globals = globals
|
||||||
|
self.env_stack.append(self.globals)
|
||||||
|
|
||||||
|
self.host_environment = host_environment
|
||||||
|
self.quote_function = quote_function
|
||||||
|
|
||||||
|
def visit_Name(self, node):
|
||||||
|
typ = super()._try_find_name(node.id)
|
||||||
|
if typ is not None:
|
||||||
|
# Value from device environment.
|
||||||
|
return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
|
||||||
|
loc=node.loc)
|
||||||
|
else:
|
||||||
|
# Try to find this value in the host environment and quote it.
|
||||||
|
if node.id in self.host_environment:
|
||||||
|
value = self.host_environment[node.id]
|
||||||
|
if inspect.isfunction(value):
|
||||||
|
# It's a function. We need to translate the function and insert
|
||||||
|
# a reference to it.
|
||||||
|
function_name = self.quote_function(value)
|
||||||
|
return asttyped.NameT(id=function_name, ctx=None,
|
||||||
|
type=self.globals[function_name],
|
||||||
|
loc=node.loc)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# It's just a value. Quote it.
|
||||||
|
synthesizer = ASTSynthesizer()
|
||||||
|
node = synthesizer.quote(value)
|
||||||
|
synthesizer.finalize()
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
|
"name '{name}' is not bound to anything", {"name":node.id},
|
||||||
|
node.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
class Stitcher:
|
class Stitcher:
|
||||||
def __init__(self, engine=None):
|
def __init__(self, engine=None):
|
||||||
|
@ -108,24 +145,30 @@ class Stitcher:
|
||||||
else:
|
else:
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
|
||||||
self.asttyped_rewriter = StitchingASTTypedRewriter(
|
self.name = ""
|
||||||
engine=self.engine, globals=prelude.globals())
|
self.typedtree = []
|
||||||
self.inferencer = Inferencer(engine=self.engine)
|
self.prelude = prelude.globals()
|
||||||
|
self.globals = {}
|
||||||
|
|
||||||
self.name = "stitched"
|
self.functions = {}
|
||||||
self.typedtree = None
|
|
||||||
self.globals = self.asttyped_rewriter.globals
|
|
||||||
|
|
||||||
self.rpc_map = {}
|
self.rpc_map = {}
|
||||||
|
|
||||||
def _iterate(self):
|
def _iterate(self):
|
||||||
|
inferencer = Inferencer(engine=self.engine)
|
||||||
|
|
||||||
# Iterate inference to fixed point.
|
# Iterate inference to fixed point.
|
||||||
self.inference_finished = False
|
self.inference_finished = False
|
||||||
while not self.inference_finished:
|
while not self.inference_finished:
|
||||||
self.inference_finished = True
|
self.inference_finished = True
|
||||||
self.inferencer.visit(self.typedtree)
|
inferencer.visit(self.typedtree)
|
||||||
|
|
||||||
def _parse_embedded_function(self, function):
|
# After we have found all functions, synthesize a module to hold them.
|
||||||
|
self.typedtree = asttyped.ModuleT(
|
||||||
|
typing_env=self.globals, globals_in_scope=set(),
|
||||||
|
body=self.typedtree, loc=None)
|
||||||
|
|
||||||
|
def _quote_embedded_function(self, function):
|
||||||
if not hasattr(function, "artiq_embedded"):
|
if not hasattr(function, "artiq_embedded"):
|
||||||
raise ValueError("{} is not an embedded function".format(repr(function)))
|
raise ValueError("{} is not an embedded function".format(repr(function)))
|
||||||
|
|
||||||
|
@ -133,25 +176,62 @@ class Stitcher:
|
||||||
embedded_function = function.artiq_embedded.function
|
embedded_function = function.artiq_embedded.function
|
||||||
source_code = inspect.getsource(embedded_function)
|
source_code = inspect.getsource(embedded_function)
|
||||||
filename = embedded_function.__code__.co_filename
|
filename = embedded_function.__code__.co_filename
|
||||||
|
module_name, _ = os.path.splitext(os.path.basename(filename))
|
||||||
first_line = embedded_function.__code__.co_firstlineno
|
first_line = embedded_function.__code__.co_firstlineno
|
||||||
|
|
||||||
|
# Extract function environment.
|
||||||
|
host_environment = dict()
|
||||||
|
host_environment.update(embedded_function.__globals__)
|
||||||
|
cells = embedded_function.__closure__
|
||||||
|
cell_names = embedded_function.__code__.co_freevars
|
||||||
|
host_environment.update({var: cells[index] for index, var in enumerate(cell_names)})
|
||||||
|
|
||||||
# Parse.
|
# Parse.
|
||||||
source_buffer = source.Buffer(source_code, filename, first_line)
|
source_buffer = source.Buffer(source_code, filename, first_line)
|
||||||
parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
|
parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
|
||||||
|
function_node = parsetree.body[0]
|
||||||
|
|
||||||
|
# Mangle the name, since we put everything into a single module.
|
||||||
|
function_node.name = "{}.{}".format(module_name, function_node.name)
|
||||||
|
|
||||||
|
# Normally, LocalExtractor would populate the typing environment
|
||||||
|
# of the module with the function name. However, since we run
|
||||||
|
# ASTTypedRewriter on the function node directly, we need to do it
|
||||||
|
# explicitly.
|
||||||
|
self.globals[function_node.name] = types.TVar()
|
||||||
|
|
||||||
|
# Memoize the function before typing it to handle recursive
|
||||||
|
# invocations.
|
||||||
|
self.functions[function] = function_node.name
|
||||||
|
|
||||||
# Rewrite into typed form.
|
# Rewrite into typed form.
|
||||||
typedtree = self.asttyped_rewriter.visit(parsetree)
|
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||||
|
engine=self.engine, prelude=self.prelude,
|
||||||
|
globals=self.globals, host_environment=host_environment,
|
||||||
|
quote_function=self._quote_function)
|
||||||
|
return asttyped_rewriter.visit(function_node)
|
||||||
|
|
||||||
return typedtree, typedtree.body[0]
|
def _quote_function(self, function):
|
||||||
|
if function in self.functions:
|
||||||
|
return self.functions[function]
|
||||||
|
|
||||||
|
# Insert the typed AST for the new function and restart inference.
|
||||||
|
# It doesn't really matter where we insert as long as it is before
|
||||||
|
# the final call.
|
||||||
|
function_node = self._quote_embedded_function(function)
|
||||||
|
self.typedtree.insert(0, function_node)
|
||||||
|
self.inference_finished = False
|
||||||
|
return function_node.name
|
||||||
|
|
||||||
def stitch_call(self, function, args, kwargs):
|
def stitch_call(self, function, args, kwargs):
|
||||||
self.typedtree, function_node = self._parse_embedded_function(function)
|
function_node = self._quote_embedded_function(function)
|
||||||
|
self.typedtree.append(function_node)
|
||||||
|
|
||||||
# We synthesize fake source code for the initial call so that
|
# We synthesize source code for the initial call so that
|
||||||
# diagnostics would have something meaningful to display to the user.
|
# diagnostics would have something meaningful to display to the user.
|
||||||
synthesizer = ASTSynthesizer()
|
synthesizer = ASTSynthesizer()
|
||||||
call_node = synthesizer.call(function_node, args, kwargs)
|
call_node = synthesizer.call(function_node, args, kwargs)
|
||||||
synthesizer.finalize()
|
synthesizer.finalize()
|
||||||
self.typedtree.body.append(call_node)
|
self.typedtree.append(call_node)
|
||||||
|
|
||||||
self._iterate()
|
self._iterate()
|
||||||
|
|
|
@ -67,7 +67,10 @@ class Module:
|
||||||
|
|
||||||
def entry_point(self):
|
def entry_point(self):
|
||||||
"""Return the name of the function that is the entry point of this module."""
|
"""Return the name of the function that is the entry point of this module."""
|
||||||
return self.name + ".__modinit__"
|
if self.name != "":
|
||||||
|
return self.name + ".__modinit__"
|
||||||
|
else:
|
||||||
|
return "__modinit__"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
|
|
|
@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
def __init__(self, module_name, engine):
|
def __init__(self, module_name, engine):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.functions = []
|
self.functions = []
|
||||||
self.name = [module_name]
|
self.name = [module_name] if module_name != "" else []
|
||||||
self.current_loc = None
|
self.current_loc = None
|
||||||
self.current_function = None
|
self.current_function = None
|
||||||
self.current_globals = set()
|
self.current_globals = set()
|
||||||
|
|
|
@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
via :class:`LocalExtractor`.
|
via :class:`LocalExtractor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, engine, globals):
|
def __init__(self, engine, prelude):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.globals = None
|
self.globals = None
|
||||||
self.env_stack = [globals]
|
self.env_stack = [prelude]
|
||||||
|
|
||||||
def _try_find_name(self, name):
|
def _try_find_name(self, name):
|
||||||
for typing_env in reversed(self.env_stack):
|
for typing_env in reversed(self.env_stack):
|
||||||
|
|
Loading…
Reference in New Issue