mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-25 11:18:27 +08:00
compiler.embedding: support calling methods marked as @kernel.
This commit is contained in:
parent
d0fd61866f
commit
c21387dc09
@ -34,10 +34,11 @@ class ObjectMap:
|
||||
return self.forward_map[obj_key]
|
||||
|
||||
class ASTSynthesizer:
|
||||
def __init__(self, type_map, value_map, expanded_from=None):
|
||||
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
|
||||
self.source = ""
|
||||
self.source_buffer = source.Buffer(self.source, "<synthesized>")
|
||||
self.type_map, self.value_map = type_map, value_map
|
||||
self.quote_function = quote_function
|
||||
self.expanded_from = expanded_from
|
||||
|
||||
def finalize(self):
|
||||
@ -82,6 +83,10 @@ class ASTSynthesizer:
|
||||
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
||||
begin_loc=begin_loc, end_loc=end_loc,
|
||||
loc=begin_loc.join(end_loc))
|
||||
elif inspect.isfunction(value) or inspect.ismethod(value):
|
||||
function_name, function_type = self.quote_function(value, self.expanded_from)
|
||||
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
|
||||
loc=self._add(repr(value)))
|
||||
else:
|
||||
quote_loc = self._add('`')
|
||||
repr_loc = self._add(repr(value))
|
||||
@ -155,6 +160,36 @@ class ASTSynthesizer:
|
||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||
loc=name_loc.join(end_loc))
|
||||
|
||||
def assign_local(self, var_name, value):
|
||||
name_loc = self._add(var_name)
|
||||
_ = self._add(" ")
|
||||
equals_loc = self._add("=")
|
||||
_ = self._add(" ")
|
||||
value_node = self.quote(value)
|
||||
|
||||
var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type,
|
||||
loc=name_loc)
|
||||
|
||||
return ast.Assign(targets=[var_node], value=value_node,
|
||||
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
|
||||
|
||||
def assign_attribute(self, obj, attr_name, value):
|
||||
obj_node = self.quote(obj)
|
||||
dot_loc = self._add(".")
|
||||
name_loc = self._add(attr_name)
|
||||
_ = self._add(" ")
|
||||
equals_loc = self._add("=")
|
||||
_ = self._add(" ")
|
||||
value_node = self.quote(value)
|
||||
|
||||
attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
|
||||
type=value_node.type,
|
||||
dot_loc=dot_loc, attr_loc=name_loc,
|
||||
loc=obj_node.loc.join(name_loc))
|
||||
|
||||
return ast.Assign(targets=[attr_node], value=value_node,
|
||||
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
|
||||
|
||||
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
def __init__(self, engine, prelude, globals, host_environment, quote):
|
||||
super().__init__(engine, prelude)
|
||||
@ -221,7 +256,20 @@ class StitchingInferencer(Inferencer):
|
||||
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
||||
# of having the host-to-ARTIQ mapping code in only one place and
|
||||
# also immediately getting proper diagnostics on type errors.
|
||||
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
|
||||
attr_value = getattr(object_value, node.attr)
|
||||
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
|
||||
and types.is_instance(object_type)):
|
||||
# In cases like:
|
||||
# class c:
|
||||
# @kernel
|
||||
# def f(self): pass
|
||||
# we want f to be defined on the class, not on the instance.
|
||||
attributes = object_type.constructor.attributes
|
||||
attr_value = attr_value.__func__
|
||||
else:
|
||||
attributes = object_type.attributes
|
||||
|
||||
ast = self.quote(attr_value, None)
|
||||
|
||||
def proxy_diagnostic(diag):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
@ -238,17 +286,17 @@ class StitchingInferencer(Inferencer):
|
||||
Inferencer(engine=proxy_engine).visit(ast)
|
||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||
|
||||
if node.attr not in object_type.attributes:
|
||||
if node.attr not in attributes:
|
||||
# We just figured out what the type should be. Add it.
|
||||
object_type.attributes[node.attr] = ast.type
|
||||
elif object_type.attributes[node.attr] != ast.type:
|
||||
attributes[node.attr] = ast.type
|
||||
elif attributes[node.attr] != ast.type:
|
||||
# Does this conflict with an earlier guess?
|
||||
printer = types.TypePrinter()
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"host object has an attribute of type {typea}, which is"
|
||||
" different from previously inferred type {typeb}",
|
||||
{"typea": printer.name(ast.type),
|
||||
"typeb": printer.name(object_type.attributes[node.attr])},
|
||||
"typeb": printer.name(attributes[node.attr])},
|
||||
object_loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
@ -261,11 +309,9 @@ class TypedtreeHasher(algorithm.Visitor):
|
||||
return self.visit(obj)
|
||||
elif isinstance(obj, types.Type):
|
||||
return hash(obj.find())
|
||||
elif isinstance(obj, list):
|
||||
return tuple(obj)
|
||||
else:
|
||||
assert obj is None or isinstance(obj, (bool, int, float, str))
|
||||
return obj
|
||||
# We don't care; only types change during inference.
|
||||
pass
|
||||
|
||||
fields = node._fields
|
||||
if hasattr(node, '_types'):
|
||||
@ -281,6 +327,7 @@ class Stitcher:
|
||||
|
||||
self.name = ""
|
||||
self.typedtree = []
|
||||
self.inject_at = 0
|
||||
self.prelude = prelude.globals()
|
||||
self.globals = {}
|
||||
|
||||
@ -290,6 +337,17 @@ class Stitcher:
|
||||
self.type_map = {}
|
||||
self.value_map = defaultdict(lambda: [])
|
||||
|
||||
def stitch_call(self, function, args, kwargs):
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.append(function_node)
|
||||
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = self._synthesizer()
|
||||
call_node = synthesizer.call(function_node, args, kwargs)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
|
||||
def finalize(self):
|
||||
inferencer = StitchingInferencer(engine=self.engine,
|
||||
value_map=self.value_map,
|
||||
@ -306,12 +364,50 @@ class Stitcher:
|
||||
break
|
||||
old_typedtree_hash = typedtree_hash
|
||||
|
||||
# For every host class we embed, add an appropriate constructor
|
||||
# as a global. This is necessary for method lookup, which uses
|
||||
# the getconstructor instruction.
|
||||
for instance_type, constructor_type in list(self.type_map.values()):
|
||||
# Do we have any direct reference to a constructor?
|
||||
if len(self.value_map[constructor_type]) > 0:
|
||||
# Yes, use it.
|
||||
constructor, _constructor_loc = self.value_map[constructor_type][0]
|
||||
else:
|
||||
# No, extract one from a reference to an instance.
|
||||
instance, _instance_loc = self.value_map[instance_type][0]
|
||||
constructor = type(instance)
|
||||
|
||||
self.globals[constructor_type.name] = constructor_type
|
||||
|
||||
synthesizer = self._synthesizer()
|
||||
ast = synthesizer.assign_local(constructor_type.name, constructor)
|
||||
synthesizer.finalize()
|
||||
self._inject(ast)
|
||||
|
||||
for attr in constructor_type.attributes:
|
||||
if types.is_function(constructor_type.attributes[attr]):
|
||||
synthesizer = self._synthesizer()
|
||||
ast = synthesizer.assign_attribute(constructor, attr,
|
||||
getattr(constructor, attr))
|
||||
synthesizer.finalize()
|
||||
self._inject(ast)
|
||||
|
||||
# After we have found all functions, synthesize a module to hold them.
|
||||
source_buffer = source.Buffer("", "<synthesized>")
|
||||
self.typedtree = asttyped.ModuleT(
|
||||
typing_env=self.globals, globals_in_scope=set(),
|
||||
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
|
||||
|
||||
def _inject(self, node):
|
||||
self.typedtree.insert(self.inject_at, node)
|
||||
self.inject_at += 1
|
||||
|
||||
def _synthesizer(self, expanded_from=None):
|
||||
return ASTSynthesizer(expanded_from=expanded_from,
|
||||
type_map=self.type_map,
|
||||
value_map=self.value_map,
|
||||
quote_function=self._quote_function)
|
||||
|
||||
def _quote_embedded_function(self, function):
|
||||
if not hasattr(function, "artiq_embedded"):
|
||||
raise ValueError("{} is not an embedded function".format(repr(function)))
|
||||
@ -414,10 +510,7 @@ class Stitcher:
|
||||
# This is tricky, because the default value might not have
|
||||
# a well-defined type in APython.
|
||||
# In this case, we bail out, but mention why we do it.
|
||||
synthesizer = ASTSynthesizer(type_map=self.type_map,
|
||||
value_map=self.value_map)
|
||||
ast = synthesizer.quote(param.default)
|
||||
synthesizer.finalize()
|
||||
ast = self._quote(param.default, None)
|
||||
|
||||
def proxy_diagnostic(diag):
|
||||
note = diagnostic.Diagnostic("note",
|
||||
@ -499,11 +592,12 @@ class Stitcher:
|
||||
self.globals[function_name] = function_type
|
||||
self.functions[function] = function_name
|
||||
|
||||
return function_name
|
||||
return function_name, function_type
|
||||
|
||||
def _quote_function(self, function, loc):
|
||||
if function in self.functions:
|
||||
return self.functions[function]
|
||||
function_name = self.functions[function]
|
||||
return function_name, self.globals[function_name]
|
||||
|
||||
if hasattr(function, "artiq_embedded"):
|
||||
if function.artiq_embedded.function is not None:
|
||||
@ -511,8 +605,8 @@ class Stitcher:
|
||||
# It doesn't really matter where we insert as long as it is before
|
||||
# the final call.
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.insert(0, function_node)
|
||||
return function_node.name
|
||||
self._inject(function_node)
|
||||
return function_node.name, self.globals[function_node.name]
|
||||
elif function.artiq_embedded.syscall is not None:
|
||||
# Insert a storage-less global whose type instructs the compiler
|
||||
# to perform a system call instead of a regular call.
|
||||
@ -527,31 +621,7 @@ class Stitcher:
|
||||
syscall=None)
|
||||
|
||||
def _quote(self, value, loc):
|
||||
if inspect.isfunction(value) or inspect.ismethod(value):
|
||||
# It's a function. We need to translate the function and insert
|
||||
# a reference to it.
|
||||
function_name = self._quote_function(value, loc)
|
||||
return asttyped.NameT(id=function_name, ctx=None,
|
||||
type=self.globals[function_name],
|
||||
loc=loc)
|
||||
|
||||
else:
|
||||
# It's just a value. Quote it.
|
||||
synthesizer = ASTSynthesizer(expanded_from=loc,
|
||||
type_map=self.type_map,
|
||||
value_map=self.value_map)
|
||||
node = synthesizer.quote(value)
|
||||
synthesizer.finalize()
|
||||
return node
|
||||
|
||||
def stitch_call(self, function, args, kwargs):
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.append(function_node)
|
||||
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = ASTSynthesizer(type_map=self.type_map,
|
||||
value_map=self.value_map)
|
||||
call_node = synthesizer.call(function_node, args, kwargs)
|
||||
synthesizer = self._synthesizer(loc)
|
||||
node = synthesizer.quote(value)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
return node
|
||||
|
@ -127,8 +127,10 @@ class Inferencer(algorithm.Visitor):
|
||||
when=" while inferring the type for self argument")
|
||||
|
||||
attr_type = types.TMethod(object_type, attr_type)
|
||||
self._unify(node.type, attr_type,
|
||||
node.loc, None)
|
||||
|
||||
if not types.is_var(attr_type):
|
||||
self._unify(node.type, attr_type,
|
||||
node.loc, None)
|
||||
else:
|
||||
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
|
||||
highlights, notes = [node.value.loc], []
|
||||
|
@ -266,7 +266,7 @@ class LLVMIRGenerator:
|
||||
elif types.is_constructor(typ):
|
||||
name = "class.{}".format(typ.name)
|
||||
else:
|
||||
name = typ.name
|
||||
name = "instance.{}".format(typ.name)
|
||||
|
||||
llty = self.llcontext.get_identified_type(name)
|
||||
if llty.elements is None:
|
||||
@ -991,7 +991,7 @@ class LLVMIRGenerator:
|
||||
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
|
||||
lambda: path() + [attr]))
|
||||
|
||||
llvalue = ll.Constant.literal_struct(llfields)
|
||||
llvalue = ll.Constant(llty.pointee, llfields)
|
||||
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
|
||||
llconst.initializer = llvalue
|
||||
llconst.linkage = "private"
|
||||
@ -1012,8 +1012,10 @@ class LLVMIRGenerator:
|
||||
elif builtins.is_str(typ):
|
||||
assert isinstance(value, (str, bytes))
|
||||
return self.llstr_of_str(value)
|
||||
elif types.is_rpc_function(typ):
|
||||
return ll.Constant.literal_struct([])
|
||||
elif types.is_function(typ):
|
||||
# RPC and C functions have no runtime representation; ARTIQ
|
||||
# functions are initialized explicitly.
|
||||
return ll.Constant(llty, ll.Undefined)
|
||||
else:
|
||||
print(typ)
|
||||
assert False
|
||||
|
Loading…
Reference in New Issue
Block a user