forked from M-Labs/artiq
compiler.embedding: support calling methods marked as @kernel.
This commit is contained in:
parent
d0fd61866f
commit
c21387dc09
|
@ -34,10 +34,11 @@ class ObjectMap:
|
||||||
return self.forward_map[obj_key]
|
return self.forward_map[obj_key]
|
||||||
|
|
||||||
class ASTSynthesizer:
|
class ASTSynthesizer:
|
||||||
def __init__(self, type_map, value_map, expanded_from=None):
|
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
|
||||||
self.source = ""
|
self.source = ""
|
||||||
self.source_buffer = source.Buffer(self.source, "<synthesized>")
|
self.source_buffer = source.Buffer(self.source, "<synthesized>")
|
||||||
self.type_map, self.value_map = type_map, value_map
|
self.type_map, self.value_map = type_map, value_map
|
||||||
|
self.quote_function = quote_function
|
||||||
self.expanded_from = expanded_from
|
self.expanded_from = expanded_from
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
|
@ -82,6 +83,10 @@ class ASTSynthesizer:
|
||||||
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
||||||
begin_loc=begin_loc, end_loc=end_loc,
|
begin_loc=begin_loc, end_loc=end_loc,
|
||||||
loc=begin_loc.join(end_loc))
|
loc=begin_loc.join(end_loc))
|
||||||
|
elif inspect.isfunction(value) or inspect.ismethod(value):
|
||||||
|
function_name, function_type = self.quote_function(value, self.expanded_from)
|
||||||
|
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
|
||||||
|
loc=self._add(repr(value)))
|
||||||
else:
|
else:
|
||||||
quote_loc = self._add('`')
|
quote_loc = self._add('`')
|
||||||
repr_loc = self._add(repr(value))
|
repr_loc = self._add(repr(value))
|
||||||
|
@ -155,6 +160,36 @@ class ASTSynthesizer:
|
||||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||||
loc=name_loc.join(end_loc))
|
loc=name_loc.join(end_loc))
|
||||||
|
|
||||||
|
def assign_local(self, var_name, value):
|
||||||
|
name_loc = self._add(var_name)
|
||||||
|
_ = self._add(" ")
|
||||||
|
equals_loc = self._add("=")
|
||||||
|
_ = self._add(" ")
|
||||||
|
value_node = self.quote(value)
|
||||||
|
|
||||||
|
var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type,
|
||||||
|
loc=name_loc)
|
||||||
|
|
||||||
|
return ast.Assign(targets=[var_node], value=value_node,
|
||||||
|
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
|
||||||
|
|
||||||
|
def assign_attribute(self, obj, attr_name, value):
|
||||||
|
obj_node = self.quote(obj)
|
||||||
|
dot_loc = self._add(".")
|
||||||
|
name_loc = self._add(attr_name)
|
||||||
|
_ = self._add(" ")
|
||||||
|
equals_loc = self._add("=")
|
||||||
|
_ = self._add(" ")
|
||||||
|
value_node = self.quote(value)
|
||||||
|
|
||||||
|
attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
|
||||||
|
type=value_node.type,
|
||||||
|
dot_loc=dot_loc, attr_loc=name_loc,
|
||||||
|
loc=obj_node.loc.join(name_loc))
|
||||||
|
|
||||||
|
return ast.Assign(targets=[attr_node], value=value_node,
|
||||||
|
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
|
||||||
|
|
||||||
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||||
def __init__(self, engine, prelude, globals, host_environment, quote):
|
def __init__(self, engine, prelude, globals, host_environment, quote):
|
||||||
super().__init__(engine, prelude)
|
super().__init__(engine, prelude)
|
||||||
|
@ -221,7 +256,20 @@ class StitchingInferencer(Inferencer):
|
||||||
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
# overhead (i.e. synthesizing a source buffer), but has the advantage
|
||||||
# of having the host-to-ARTIQ mapping code in only one place and
|
# of having the host-to-ARTIQ mapping code in only one place and
|
||||||
# also immediately getting proper diagnostics on type errors.
|
# also immediately getting proper diagnostics on type errors.
|
||||||
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
|
attr_value = getattr(object_value, node.attr)
|
||||||
|
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
|
||||||
|
and types.is_instance(object_type)):
|
||||||
|
# In cases like:
|
||||||
|
# class c:
|
||||||
|
# @kernel
|
||||||
|
# def f(self): pass
|
||||||
|
# we want f to be defined on the class, not on the instance.
|
||||||
|
attributes = object_type.constructor.attributes
|
||||||
|
attr_value = attr_value.__func__
|
||||||
|
else:
|
||||||
|
attributes = object_type.attributes
|
||||||
|
|
||||||
|
ast = self.quote(attr_value, None)
|
||||||
|
|
||||||
def proxy_diagnostic(diag):
|
def proxy_diagnostic(diag):
|
||||||
note = diagnostic.Diagnostic("note",
|
note = diagnostic.Diagnostic("note",
|
||||||
|
@ -238,17 +286,17 @@ class StitchingInferencer(Inferencer):
|
||||||
Inferencer(engine=proxy_engine).visit(ast)
|
Inferencer(engine=proxy_engine).visit(ast)
|
||||||
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
IntMonomorphizer(engine=proxy_engine).visit(ast)
|
||||||
|
|
||||||
if node.attr not in object_type.attributes:
|
if node.attr not in attributes:
|
||||||
# We just figured out what the type should be. Add it.
|
# We just figured out what the type should be. Add it.
|
||||||
object_type.attributes[node.attr] = ast.type
|
attributes[node.attr] = ast.type
|
||||||
elif object_type.attributes[node.attr] != ast.type:
|
elif attributes[node.attr] != ast.type:
|
||||||
# Does this conflict with an earlier guess?
|
# Does this conflict with an earlier guess?
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"host object has an attribute of type {typea}, which is"
|
"host object has an attribute of type {typea}, which is"
|
||||||
" different from previously inferred type {typeb}",
|
" different from previously inferred type {typeb}",
|
||||||
{"typea": printer.name(ast.type),
|
{"typea": printer.name(ast.type),
|
||||||
"typeb": printer.name(object_type.attributes[node.attr])},
|
"typeb": printer.name(attributes[node.attr])},
|
||||||
object_loc)
|
object_loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
@ -261,11 +309,9 @@ class TypedtreeHasher(algorithm.Visitor):
|
||||||
return self.visit(obj)
|
return self.visit(obj)
|
||||||
elif isinstance(obj, types.Type):
|
elif isinstance(obj, types.Type):
|
||||||
return hash(obj.find())
|
return hash(obj.find())
|
||||||
elif isinstance(obj, list):
|
|
||||||
return tuple(obj)
|
|
||||||
else:
|
else:
|
||||||
assert obj is None or isinstance(obj, (bool, int, float, str))
|
# We don't care; only types change during inference.
|
||||||
return obj
|
pass
|
||||||
|
|
||||||
fields = node._fields
|
fields = node._fields
|
||||||
if hasattr(node, '_types'):
|
if hasattr(node, '_types'):
|
||||||
|
@ -281,6 +327,7 @@ class Stitcher:
|
||||||
|
|
||||||
self.name = ""
|
self.name = ""
|
||||||
self.typedtree = []
|
self.typedtree = []
|
||||||
|
self.inject_at = 0
|
||||||
self.prelude = prelude.globals()
|
self.prelude = prelude.globals()
|
||||||
self.globals = {}
|
self.globals = {}
|
||||||
|
|
||||||
|
@ -290,6 +337,17 @@ class Stitcher:
|
||||||
self.type_map = {}
|
self.type_map = {}
|
||||||
self.value_map = defaultdict(lambda: [])
|
self.value_map = defaultdict(lambda: [])
|
||||||
|
|
||||||
|
def stitch_call(self, function, args, kwargs):
|
||||||
|
function_node = self._quote_embedded_function(function)
|
||||||
|
self.typedtree.append(function_node)
|
||||||
|
|
||||||
|
# We synthesize source code for the initial call so that
|
||||||
|
# diagnostics would have something meaningful to display to the user.
|
||||||
|
synthesizer = self._synthesizer()
|
||||||
|
call_node = synthesizer.call(function_node, args, kwargs)
|
||||||
|
synthesizer.finalize()
|
||||||
|
self.typedtree.append(call_node)
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
inferencer = StitchingInferencer(engine=self.engine,
|
inferencer = StitchingInferencer(engine=self.engine,
|
||||||
value_map=self.value_map,
|
value_map=self.value_map,
|
||||||
|
@ -306,12 +364,50 @@ class Stitcher:
|
||||||
break
|
break
|
||||||
old_typedtree_hash = typedtree_hash
|
old_typedtree_hash = typedtree_hash
|
||||||
|
|
||||||
|
# For every host class we embed, add an appropriate constructor
|
||||||
|
# as a global. This is necessary for method lookup, which uses
|
||||||
|
# the getconstructor instruction.
|
||||||
|
for instance_type, constructor_type in list(self.type_map.values()):
|
||||||
|
# Do we have any direct reference to a constructor?
|
||||||
|
if len(self.value_map[constructor_type]) > 0:
|
||||||
|
# Yes, use it.
|
||||||
|
constructor, _constructor_loc = self.value_map[constructor_type][0]
|
||||||
|
else:
|
||||||
|
# No, extract one from a reference to an instance.
|
||||||
|
instance, _instance_loc = self.value_map[instance_type][0]
|
||||||
|
constructor = type(instance)
|
||||||
|
|
||||||
|
self.globals[constructor_type.name] = constructor_type
|
||||||
|
|
||||||
|
synthesizer = self._synthesizer()
|
||||||
|
ast = synthesizer.assign_local(constructor_type.name, constructor)
|
||||||
|
synthesizer.finalize()
|
||||||
|
self._inject(ast)
|
||||||
|
|
||||||
|
for attr in constructor_type.attributes:
|
||||||
|
if types.is_function(constructor_type.attributes[attr]):
|
||||||
|
synthesizer = self._synthesizer()
|
||||||
|
ast = synthesizer.assign_attribute(constructor, attr,
|
||||||
|
getattr(constructor, attr))
|
||||||
|
synthesizer.finalize()
|
||||||
|
self._inject(ast)
|
||||||
|
|
||||||
# After we have found all functions, synthesize a module to hold them.
|
# After we have found all functions, synthesize a module to hold them.
|
||||||
source_buffer = source.Buffer("", "<synthesized>")
|
source_buffer = source.Buffer("", "<synthesized>")
|
||||||
self.typedtree = asttyped.ModuleT(
|
self.typedtree = asttyped.ModuleT(
|
||||||
typing_env=self.globals, globals_in_scope=set(),
|
typing_env=self.globals, globals_in_scope=set(),
|
||||||
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
|
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
|
||||||
|
|
||||||
|
def _inject(self, node):
|
||||||
|
self.typedtree.insert(self.inject_at, node)
|
||||||
|
self.inject_at += 1
|
||||||
|
|
||||||
|
def _synthesizer(self, expanded_from=None):
|
||||||
|
return ASTSynthesizer(expanded_from=expanded_from,
|
||||||
|
type_map=self.type_map,
|
||||||
|
value_map=self.value_map,
|
||||||
|
quote_function=self._quote_function)
|
||||||
|
|
||||||
def _quote_embedded_function(self, function):
|
def _quote_embedded_function(self, function):
|
||||||
if not hasattr(function, "artiq_embedded"):
|
if not hasattr(function, "artiq_embedded"):
|
||||||
raise ValueError("{} is not an embedded function".format(repr(function)))
|
raise ValueError("{} is not an embedded function".format(repr(function)))
|
||||||
|
@ -414,10 +510,7 @@ class Stitcher:
|
||||||
# This is tricky, because the default value might not have
|
# This is tricky, because the default value might not have
|
||||||
# a well-defined type in APython.
|
# a well-defined type in APython.
|
||||||
# In this case, we bail out, but mention why we do it.
|
# In this case, we bail out, but mention why we do it.
|
||||||
synthesizer = ASTSynthesizer(type_map=self.type_map,
|
ast = self._quote(param.default, None)
|
||||||
value_map=self.value_map)
|
|
||||||
ast = synthesizer.quote(param.default)
|
|
||||||
synthesizer.finalize()
|
|
||||||
|
|
||||||
def proxy_diagnostic(diag):
|
def proxy_diagnostic(diag):
|
||||||
note = diagnostic.Diagnostic("note",
|
note = diagnostic.Diagnostic("note",
|
||||||
|
@ -499,11 +592,12 @@ class Stitcher:
|
||||||
self.globals[function_name] = function_type
|
self.globals[function_name] = function_type
|
||||||
self.functions[function] = function_name
|
self.functions[function] = function_name
|
||||||
|
|
||||||
return function_name
|
return function_name, function_type
|
||||||
|
|
||||||
def _quote_function(self, function, loc):
|
def _quote_function(self, function, loc):
|
||||||
if function in self.functions:
|
if function in self.functions:
|
||||||
return self.functions[function]
|
function_name = self.functions[function]
|
||||||
|
return function_name, self.globals[function_name]
|
||||||
|
|
||||||
if hasattr(function, "artiq_embedded"):
|
if hasattr(function, "artiq_embedded"):
|
||||||
if function.artiq_embedded.function is not None:
|
if function.artiq_embedded.function is not None:
|
||||||
|
@ -511,8 +605,8 @@ class Stitcher:
|
||||||
# It doesn't really matter where we insert as long as it is before
|
# It doesn't really matter where we insert as long as it is before
|
||||||
# the final call.
|
# the final call.
|
||||||
function_node = self._quote_embedded_function(function)
|
function_node = self._quote_embedded_function(function)
|
||||||
self.typedtree.insert(0, function_node)
|
self._inject(function_node)
|
||||||
return function_node.name
|
return function_node.name, self.globals[function_node.name]
|
||||||
elif function.artiq_embedded.syscall is not None:
|
elif function.artiq_embedded.syscall is not None:
|
||||||
# Insert a storage-less global whose type instructs the compiler
|
# Insert a storage-less global whose type instructs the compiler
|
||||||
# to perform a system call instead of a regular call.
|
# to perform a system call instead of a regular call.
|
||||||
|
@ -527,31 +621,7 @@ class Stitcher:
|
||||||
syscall=None)
|
syscall=None)
|
||||||
|
|
||||||
def _quote(self, value, loc):
|
def _quote(self, value, loc):
|
||||||
if inspect.isfunction(value) or inspect.ismethod(value):
|
synthesizer = self._synthesizer(loc)
|
||||||
# It's a function. We need to translate the function and insert
|
node = synthesizer.quote(value)
|
||||||
# a reference to it.
|
|
||||||
function_name = self._quote_function(value, loc)
|
|
||||||
return asttyped.NameT(id=function_name, ctx=None,
|
|
||||||
type=self.globals[function_name],
|
|
||||||
loc=loc)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# It's just a value. Quote it.
|
|
||||||
synthesizer = ASTSynthesizer(expanded_from=loc,
|
|
||||||
type_map=self.type_map,
|
|
||||||
value_map=self.value_map)
|
|
||||||
node = synthesizer.quote(value)
|
|
||||||
synthesizer.finalize()
|
|
||||||
return node
|
|
||||||
|
|
||||||
def stitch_call(self, function, args, kwargs):
|
|
||||||
function_node = self._quote_embedded_function(function)
|
|
||||||
self.typedtree.append(function_node)
|
|
||||||
|
|
||||||
# We synthesize source code for the initial call so that
|
|
||||||
# diagnostics would have something meaningful to display to the user.
|
|
||||||
synthesizer = ASTSynthesizer(type_map=self.type_map,
|
|
||||||
value_map=self.value_map)
|
|
||||||
call_node = synthesizer.call(function_node, args, kwargs)
|
|
||||||
synthesizer.finalize()
|
synthesizer.finalize()
|
||||||
self.typedtree.append(call_node)
|
return node
|
||||||
|
|
|
@ -127,8 +127,10 @@ class Inferencer(algorithm.Visitor):
|
||||||
when=" while inferring the type for self argument")
|
when=" while inferring the type for self argument")
|
||||||
|
|
||||||
attr_type = types.TMethod(object_type, attr_type)
|
attr_type = types.TMethod(object_type, attr_type)
|
||||||
self._unify(node.type, attr_type,
|
|
||||||
node.loc, None)
|
if not types.is_var(attr_type):
|
||||||
|
self._unify(node.type, attr_type,
|
||||||
|
node.loc, None)
|
||||||
else:
|
else:
|
||||||
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
|
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
|
||||||
highlights, notes = [node.value.loc], []
|
highlights, notes = [node.value.loc], []
|
||||||
|
|
|
@ -266,7 +266,7 @@ class LLVMIRGenerator:
|
||||||
elif types.is_constructor(typ):
|
elif types.is_constructor(typ):
|
||||||
name = "class.{}".format(typ.name)
|
name = "class.{}".format(typ.name)
|
||||||
else:
|
else:
|
||||||
name = typ.name
|
name = "instance.{}".format(typ.name)
|
||||||
|
|
||||||
llty = self.llcontext.get_identified_type(name)
|
llty = self.llcontext.get_identified_type(name)
|
||||||
if llty.elements is None:
|
if llty.elements is None:
|
||||||
|
@ -991,7 +991,7 @@ class LLVMIRGenerator:
|
||||||
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
|
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
|
||||||
lambda: path() + [attr]))
|
lambda: path() + [attr]))
|
||||||
|
|
||||||
llvalue = ll.Constant.literal_struct(llfields)
|
llvalue = ll.Constant(llty.pointee, llfields)
|
||||||
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
|
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
|
||||||
llconst.initializer = llvalue
|
llconst.initializer = llvalue
|
||||||
llconst.linkage = "private"
|
llconst.linkage = "private"
|
||||||
|
@ -1012,8 +1012,10 @@ class LLVMIRGenerator:
|
||||||
elif builtins.is_str(typ):
|
elif builtins.is_str(typ):
|
||||||
assert isinstance(value, (str, bytes))
|
assert isinstance(value, (str, bytes))
|
||||||
return self.llstr_of_str(value)
|
return self.llstr_of_str(value)
|
||||||
elif types.is_rpc_function(typ):
|
elif types.is_function(typ):
|
||||||
return ll.Constant.literal_struct([])
|
# RPC and C functions have no runtime representation; ARTIQ
|
||||||
|
# functions are initialized explicitly.
|
||||||
|
return ll.Constant(llty, ll.Undefined)
|
||||||
else:
|
else:
|
||||||
print(typ)
|
print(typ)
|
||||||
assert False
|
assert False
|
||||||
|
|
Loading…
Reference in New Issue