compiler.embedding: support RPC functions as host attribute values.

This commit is contained in:
whitequark 2015-08-27 05:53:18 -05:00
parent 04bd2421ad
commit c62b16d5e1
2 changed files with 36 additions and 39 deletions

View File

@ -12,7 +12,6 @@ from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from .validators import MonomorphismValidator
class ObjectMap: class ObjectMap:
@ -156,16 +155,13 @@ class ASTSynthesizer:
loc=name_loc.join(end_loc)) loc=name_loc.join(end_loc))
class StitchingASTTypedRewriter(ASTTypedRewriter): class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote_function, def __init__(self, engine, prelude, globals, host_environment, quote):
type_map, value_map):
super().__init__(engine, prelude) super().__init__(engine, prelude)
self.globals = globals self.globals = globals
self.env_stack.append(self.globals) self.env_stack.append(self.globals)
self.host_environment = host_environment self.host_environment = host_environment
self.quote_function = quote_function self.quote = quote
self.type_map = type_map
self.value_map = value_map
def visit_Name(self, node): def visit_Name(self, node):
typ = super()._try_find_name(node.id) typ = super()._try_find_name(node.id)
@ -176,23 +172,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
else: else:
# Try to find this value in the host environment and quote it. # Try to find this value in the host environment and quote it.
if node.id in self.host_environment: if node.id in self.host_environment:
value = self.host_environment[node.id] return self.quote(self.host_environment[node.id], node.loc)
if inspect.isfunction(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self.quote_function(value, node.loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=node.loc)
else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=node.loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value)
synthesizer.finalize()
return node
else: else:
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything", {"name":node.id}, "name '{name}' is not bound to anything", {"name":node.id},
@ -200,9 +180,10 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
self.engine.process(diag) self.engine.process(diag)
class StitchingInferencer(Inferencer): class StitchingInferencer(Inferencer):
def __init__(self, engine, type_map, value_map): def __init__(self, engine, value_map, quote):
super().__init__(engine) super().__init__(engine)
self.type_map, self.value_map = type_map, value_map self.value_map = value_map
self.quote = quote
def visit_AttributeT(self, node): def visit_AttributeT(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -239,10 +220,7 @@ class StitchingInferencer(Inferencer):
# overhead (i.e. synthesizing a source buffer), but has the advantage # overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and # of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors. # also immediately getting proper diagnostics on type errors.
synthesizer = ASTSynthesizer(type_map=self.type_map, ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
value_map=self.value_map)
ast = synthesizer.quote(getattr(object_value, node.attr))
synthesizer.finalize()
def proxy_diagnostic(diag): def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
@ -258,7 +236,6 @@ class StitchingInferencer(Inferencer):
proxy_engine.process = proxy_diagnostic proxy_engine.process = proxy_diagnostic
Inferencer(engine=proxy_engine).visit(ast) Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast) IntMonomorphizer(engine=proxy_engine).visit(ast)
MonomorphismValidator(engine=proxy_engine).visit(ast)
if node.attr not in object_type.attributes: if node.attr not in object_type.attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
@ -296,7 +273,8 @@ class Stitcher:
def finalize(self): def finalize(self):
inferencer = StitchingInferencer(engine=self.engine, inferencer = StitchingInferencer(engine=self.engine,
type_map=self.type_map, value_map=self.value_map) value_map=self.value_map,
quote=self._quote)
# Iterate inference to fixed point. # Iterate inference to fixed point.
self.inference_finished = False self.inference_finished = False
@ -350,8 +328,7 @@ class Stitcher:
asttyped_rewriter = StitchingASTTypedRewriter( asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude, engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote_function=self._quote_function, quote=self._quote)
type_map=self.type_map, value_map=self.value_map)
return asttyped_rewriter.visit(function_node) return asttyped_rewriter.visit(function_node)
def _function_loc(self, function): def _function_loc(self, function):
@ -526,6 +503,24 @@ class Stitcher:
return self._quote_foreign_function(function, loc, return self._quote_foreign_function(function, loc,
syscall=None) syscall=None)
def _quote(self, value, loc):
if inspect.isfunction(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self._quote_function(value, loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=loc)
else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value)
synthesizer.finalize()
return node
def stitch_call(self, function, args, kwargs): def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function) function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node) self.typedtree.append(function_node)

View File

@ -991,6 +991,11 @@ class LLVMIRGenerator:
lambda: path() + [attr])) lambda: path() + [attr]))
llvalue = ll.Constant.literal_struct(llfields) llvalue = ll.Constant.literal_struct(llfields)
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue
llconst.linkage = "private"
self.llobject_map[value_id] = llconst
return llconst
elif builtins.is_none(typ): elif builtins.is_none(typ):
assert value is None assert value is None
return ll.Constant.literal_struct([]) return ll.Constant.literal_struct([])
@ -1006,15 +1011,12 @@ class LLVMIRGenerator:
elif builtins.is_str(typ): elif builtins.is_str(typ):
assert isinstance(value, (str, bytes)) assert isinstance(value, (str, bytes))
return self.llstr_of_str(value) return self.llstr_of_str(value)
elif types.is_rpc_function(typ):
return ll.Constant.literal_struct([])
else: else:
print(typ)
assert False assert False
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue
llconst.linkage = "private"
self.llobject_map[value_id] = llconst
return llconst
def process_Quote(self, insn): def process_Quote(self, insn):
assert self.object_map is not None assert self.object_map is not None
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])