compiler.embedding: support calling methods marked as @kernel.

This commit is contained in:
whitequark 2015-08-27 19:46:50 -05:00
parent d0fd61866f
commit c21387dc09
3 changed files with 125 additions and 51 deletions

View File

@ -34,10 +34,11 @@ class ObjectMap:
return self.forward_map[obj_key] return self.forward_map[obj_key]
class ASTSynthesizer: class ASTSynthesizer:
def __init__(self, type_map, value_map, expanded_from=None): def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
self.source = "" self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>") self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.type_map, self.value_map = type_map, value_map self.type_map, self.value_map = type_map, value_map
self.quote_function = quote_function
self.expanded_from = expanded_from self.expanded_from = expanded_from
def finalize(self): def finalize(self):
@ -82,6 +83,10 @@ class ASTSynthesizer:
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc, begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc)) loc=begin_loc.join(end_loc))
elif inspect.isfunction(value) or inspect.ismethod(value):
function_name, function_type = self.quote_function(value, self.expanded_from)
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
loc=self._add(repr(value)))
else: else:
quote_loc = self._add('`') quote_loc = self._add('`')
repr_loc = self._add(repr(value)) repr_loc = self._add(repr(value))
@ -155,6 +160,36 @@ class ASTSynthesizer:
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc)) loc=name_loc.join(end_loc))
def assign_local(self, var_name, value):
name_loc = self._add(var_name)
_ = self._add(" ")
equals_loc = self._add("=")
_ = self._add(" ")
value_node = self.quote(value)
var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type,
loc=name_loc)
return ast.Assign(targets=[var_node], value=value_node,
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
def assign_attribute(self, obj, attr_name, value):
obj_node = self.quote(obj)
dot_loc = self._add(".")
name_loc = self._add(attr_name)
_ = self._add(" ")
equals_loc = self._add("=")
_ = self._add(" ")
value_node = self.quote(value)
attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
type=value_node.type,
dot_loc=dot_loc, attr_loc=name_loc,
loc=obj_node.loc.join(name_loc))
return ast.Assign(targets=[attr_node], value=value_node,
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
class StitchingASTTypedRewriter(ASTTypedRewriter): class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote): def __init__(self, engine, prelude, globals, host_environment, quote):
super().__init__(engine, prelude) super().__init__(engine, prelude)
@ -221,7 +256,20 @@ class StitchingInferencer(Inferencer):
# overhead (i.e. synthesizing a source buffer), but has the advantage # overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and # of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors. # also immediately getting proper diagnostics on type errors.
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from) attr_value = getattr(object_value, node.attr)
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
and types.is_instance(object_type)):
# In cases like:
# class c:
# @kernel
# def f(self): pass
# we want f to be defined on the class, not on the instance.
attributes = object_type.constructor.attributes
attr_value = attr_value.__func__
else:
attributes = object_type.attributes
ast = self.quote(attr_value, None)
def proxy_diagnostic(diag): def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
@ -238,17 +286,17 @@ class StitchingInferencer(Inferencer):
Inferencer(engine=proxy_engine).visit(ast) Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast) IntMonomorphizer(engine=proxy_engine).visit(ast)
if node.attr not in object_type.attributes: if node.attr not in attributes:
# We just figured out what the type should be. Add it. # We just figured out what the type should be. Add it.
object_type.attributes[node.attr] = ast.type attributes[node.attr] = ast.type
elif object_type.attributes[node.attr] != ast.type: elif attributes[node.attr] != ast.type:
# Does this conflict with an earlier guess? # Does this conflict with an earlier guess?
printer = types.TypePrinter() printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"host object has an attribute of type {typea}, which is" "host object has an attribute of type {typea}, which is"
" different from previously inferred type {typeb}", " different from previously inferred type {typeb}",
{"typea": printer.name(ast.type), {"typea": printer.name(ast.type),
"typeb": printer.name(object_type.attributes[node.attr])}, "typeb": printer.name(attributes[node.attr])},
object_loc) object_loc)
self.engine.process(diag) self.engine.process(diag)
@ -261,11 +309,9 @@ class TypedtreeHasher(algorithm.Visitor):
return self.visit(obj) return self.visit(obj)
elif isinstance(obj, types.Type): elif isinstance(obj, types.Type):
return hash(obj.find()) return hash(obj.find())
elif isinstance(obj, list):
return tuple(obj)
else: else:
assert obj is None or isinstance(obj, (bool, int, float, str)) # We don't care; only types change during inference.
return obj pass
fields = node._fields fields = node._fields
if hasattr(node, '_types'): if hasattr(node, '_types'):
@ -281,6 +327,7 @@ class Stitcher:
self.name = "" self.name = ""
self.typedtree = [] self.typedtree = []
self.inject_at = 0
self.prelude = prelude.globals() self.prelude = prelude.globals()
self.globals = {} self.globals = {}
@ -290,6 +337,17 @@ class Stitcher:
self.type_map = {} self.type_map = {}
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)
# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer()
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize()
self.typedtree.append(call_node)
def finalize(self): def finalize(self):
inferencer = StitchingInferencer(engine=self.engine, inferencer = StitchingInferencer(engine=self.engine,
value_map=self.value_map, value_map=self.value_map,
@ -306,12 +364,50 @@ class Stitcher:
break break
old_typedtree_hash = typedtree_hash old_typedtree_hash = typedtree_hash
# For every host class we embed, add an appropriate constructor
# as a global. This is necessary for method lookup, which uses
# the getconstructor instruction.
for instance_type, constructor_type in list(self.type_map.values()):
# Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0:
# Yes, use it.
constructor, _constructor_loc = self.value_map[constructor_type][0]
else:
# No, extract one from a reference to an instance.
instance, _instance_loc = self.value_map[instance_type][0]
constructor = type(instance)
self.globals[constructor_type.name] = constructor_type
synthesizer = self._synthesizer()
ast = synthesizer.assign_local(constructor_type.name, constructor)
synthesizer.finalize()
self._inject(ast)
for attr in constructor_type.attributes:
if types.is_function(constructor_type.attributes[attr]):
synthesizer = self._synthesizer()
ast = synthesizer.assign_attribute(constructor, attr,
getattr(constructor, attr))
synthesizer.finalize()
self._inject(ast)
# After we have found all functions, synthesize a module to hold them. # After we have found all functions, synthesize a module to hold them.
source_buffer = source.Buffer("", "<synthesized>") source_buffer = source.Buffer("", "<synthesized>")
self.typedtree = asttyped.ModuleT( self.typedtree = asttyped.ModuleT(
typing_env=self.globals, globals_in_scope=set(), typing_env=self.globals, globals_in_scope=set(),
body=self.typedtree, loc=source.Range(source_buffer, 0, 0)) body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
def _inject(self, node):
self.typedtree.insert(self.inject_at, node)
self.inject_at += 1
def _synthesizer(self, expanded_from=None):
return ASTSynthesizer(expanded_from=expanded_from,
type_map=self.type_map,
value_map=self.value_map,
quote_function=self._quote_function)
def _quote_embedded_function(self, function): def _quote_embedded_function(self, function):
if not hasattr(function, "artiq_embedded"): if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function))) raise ValueError("{} is not an embedded function".format(repr(function)))
@ -414,10 +510,7 @@ class Stitcher:
# This is tricky, because the default value might not have # This is tricky, because the default value might not have
# a well-defined type in APython. # a well-defined type in APython.
# In this case, we bail out, but mention why we do it. # In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer(type_map=self.type_map, ast = self._quote(param.default, None)
value_map=self.value_map)
ast = synthesizer.quote(param.default)
synthesizer.finalize()
def proxy_diagnostic(diag): def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
@ -499,11 +592,12 @@ class Stitcher:
self.globals[function_name] = function_type self.globals[function_name] = function_type
self.functions[function] = function_name self.functions[function] = function_name
return function_name return function_name, function_type
def _quote_function(self, function, loc): def _quote_function(self, function, loc):
if function in self.functions: if function in self.functions:
return self.functions[function] function_name = self.functions[function]
return function_name, self.globals[function_name]
if hasattr(function, "artiq_embedded"): if hasattr(function, "artiq_embedded"):
if function.artiq_embedded.function is not None: if function.artiq_embedded.function is not None:
@ -511,8 +605,8 @@ class Stitcher:
# It doesn't really matter where we insert as long as it is before # It doesn't really matter where we insert as long as it is before
# the final call. # the final call.
function_node = self._quote_embedded_function(function) function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node) self._inject(function_node)
return function_node.name return function_node.name, self.globals[function_node.name]
elif function.artiq_embedded.syscall is not None: elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call. # to perform a system call instead of a regular call.
@ -527,31 +621,7 @@ class Stitcher:
syscall=None) syscall=None)
def _quote(self, value, loc): def _quote(self, value, loc):
if inspect.isfunction(value) or inspect.ismethod(value): synthesizer = self._synthesizer(loc)
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self._quote_function(value, loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=loc)
else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value) node = synthesizer.quote(value)
synthesizer.finalize() synthesizer.finalize()
return node return node
def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)
# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize()
self.typedtree.append(call_node)

View File

@ -127,6 +127,8 @@ class Inferencer(algorithm.Visitor):
when=" while inferring the type for self argument") when=" while inferring the type for self argument")
attr_type = types.TMethod(object_type, attr_type) attr_type = types.TMethod(object_type, attr_type)
if not types.is_var(attr_type):
self._unify(node.type, attr_type, self._unify(node.type, attr_type,
node.loc, None) node.loc, None)
else: else:

View File

@ -266,7 +266,7 @@ class LLVMIRGenerator:
elif types.is_constructor(typ): elif types.is_constructor(typ):
name = "class.{}".format(typ.name) name = "class.{}".format(typ.name)
else: else:
name = typ.name name = "instance.{}".format(typ.name)
llty = self.llcontext.get_identified_type(name) llty = self.llcontext.get_identified_type(name)
if llty.elements is None: if llty.elements is None:
@ -991,7 +991,7 @@ class LLVMIRGenerator:
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr], llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
lambda: path() + [attr])) lambda: path() + [attr]))
llvalue = ll.Constant.literal_struct(llfields) llvalue = ll.Constant(llty.pointee, llfields)
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name) llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue llconst.initializer = llvalue
llconst.linkage = "private" llconst.linkage = "private"
@ -1012,8 +1012,10 @@ class LLVMIRGenerator:
elif builtins.is_str(typ): elif builtins.is_str(typ):
assert isinstance(value, (str, bytes)) assert isinstance(value, (str, bytes))
return self.llstr_of_str(value) return self.llstr_of_str(value)
elif types.is_rpc_function(typ): elif types.is_function(typ):
return ll.Constant.literal_struct([]) # RPC and C functions have no runtime representation; ARTIQ
# functions are initialized explicitly.
return ll.Constant(llty, ll.Undefined)
else: else:
print(typ) print(typ)
assert False assert False