forked from M-Labs/artiq
parent
640022122b
commit
355af3e569
|
@ -22,6 +22,26 @@ from .transforms.asttyped_rewriter import LocalExtractor
|
||||||
def coredevice_print(x): print(x)
|
def coredevice_print(x): print(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SpecializedFunction:
|
||||||
|
def __init__(self, instance_type, host_function):
|
||||||
|
self.instance_type = instance_type
|
||||||
|
self.host_function = host_function
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, tuple):
|
||||||
|
return (self.instance_type == other[0] or
|
||||||
|
self.host_function == other[1])
|
||||||
|
else:
|
||||||
|
return (self.instance_type == other.instance_type or
|
||||||
|
self.host_function == other.host_function)
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not self == other
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.instance_type, self.host_function))
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingMap:
|
class EmbeddingMap:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.object_current_key = 0
|
self.object_current_key = 0
|
||||||
|
@ -31,17 +51,14 @@ class EmbeddingMap:
|
||||||
self.function_map = {}
|
self.function_map = {}
|
||||||
|
|
||||||
# Types
|
# Types
|
||||||
def store_type(self, typ, instance_type, constructor_type):
|
def store_type(self, host_type, instance_type, constructor_type):
|
||||||
self.type_map[typ] = (instance_type, constructor_type)
|
self.type_map[host_type] = (instance_type, constructor_type)
|
||||||
|
|
||||||
def retrieve_type(self, typ):
|
def retrieve_type(self, host_type):
|
||||||
return self.type_map[typ]
|
return self.type_map[host_type]
|
||||||
|
|
||||||
def has_type(self, typ):
|
def has_type(self, host_type):
|
||||||
return typ in self.type_map
|
return host_type in self.type_map
|
||||||
|
|
||||||
def iter_types(self):
|
|
||||||
return self.type_map.values()
|
|
||||||
|
|
||||||
# Functions
|
# Functions
|
||||||
def store_function(self, function, ir_function_name):
|
def store_function(self, function, ir_function_name):
|
||||||
|
@ -50,6 +67,9 @@ class EmbeddingMap:
|
||||||
def retrieve_function(self, function):
|
def retrieve_function(self, function):
|
||||||
return self.function_map[function]
|
return self.function_map[function]
|
||||||
|
|
||||||
|
def specialize_function(self, instance_type, host_function):
|
||||||
|
return SpecializedFunction(instance_type, host_function)
|
||||||
|
|
||||||
# Objects
|
# Objects
|
||||||
def store_object(self, obj_ref):
|
def store_object(self, obj_ref):
|
||||||
obj_id = id(obj_ref)
|
obj_id = id(obj_ref)
|
||||||
|
@ -65,12 +85,22 @@ class EmbeddingMap:
|
||||||
return self.object_forward_map[obj_key]
|
return self.object_forward_map[obj_key]
|
||||||
|
|
||||||
def iter_objects(self):
|
def iter_objects(self):
|
||||||
return self.object_forward_map.keys()
|
for obj_id in self.object_forward_map.keys():
|
||||||
|
obj_ref = self.object_forward_map[obj_id]
|
||||||
|
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
|
||||||
|
pytypes.BuiltinFunctionType, SpecializedFunction)):
|
||||||
|
continue
|
||||||
|
elif isinstance(obj_ref, type):
|
||||||
|
_, obj_typ = self.type_map[obj_ref]
|
||||||
|
else:
|
||||||
|
obj_typ, _ = self.type_map[type(obj_ref)]
|
||||||
|
yield obj_id, obj_ref, obj_typ
|
||||||
|
|
||||||
def has_rpc(self):
|
def has_rpc(self):
|
||||||
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
|
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
|
||||||
self.object_forward_map.values()))
|
self.object_forward_map.values()))
|
||||||
|
|
||||||
|
|
||||||
class ASTSynthesizer:
|
class ASTSynthesizer:
|
||||||
def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
|
def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
|
||||||
self.source = ""
|
self.source = ""
|
||||||
|
@ -128,7 +158,8 @@ class ASTSynthesizer:
|
||||||
begin_loc=begin_loc, end_loc=end_loc,
|
begin_loc=begin_loc, end_loc=end_loc,
|
||||||
loc=begin_loc.join(end_loc))
|
loc=begin_loc.join(end_loc))
|
||||||
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
||||||
isinstance(value, pytypes.BuiltinFunctionType):
|
isinstance(value, pytypes.BuiltinFunctionType) or \
|
||||||
|
isinstance(value, SpecializedFunction):
|
||||||
if inspect.ismethod(value):
|
if inspect.ismethod(value):
|
||||||
quoted_self = self.quote(value.__self__)
|
quoted_self = self.quote(value.__self__)
|
||||||
function_type = self.quote_function(value.__func__, self.expanded_from)
|
function_type = self.quote_function(value.__func__, self.expanded_from)
|
||||||
|
@ -139,7 +170,7 @@ class ASTSynthesizer:
|
||||||
loc = quoted_self.loc.join(name_loc)
|
loc = quoted_self.loc.join(name_loc)
|
||||||
return asttyped.QuoteT(value=value, type=method_type,
|
return asttyped.QuoteT(value=value, type=method_type,
|
||||||
self_loc=quoted_self.loc, loc=loc)
|
self_loc=quoted_self.loc, loc=loc)
|
||||||
else:
|
else: # function
|
||||||
function_type = self.quote_function(value, self.expanded_from)
|
function_type = self.quote_function(value, self.expanded_from)
|
||||||
|
|
||||||
quote_loc = self._add('`')
|
quote_loc = self._add('`')
|
||||||
|
@ -417,7 +448,7 @@ class StitchingInferencer(Inferencer):
|
||||||
# def f(self): pass
|
# def f(self): pass
|
||||||
# we want f to be defined on the class, not on the instance.
|
# we want f to be defined on the class, not on the instance.
|
||||||
attributes = object_type.constructor.attributes
|
attributes = object_type.constructor.attributes
|
||||||
attr_value = attr_value.__func__
|
attr_value = SpecializedFunction(object_type, attr_value.__func__)
|
||||||
else:
|
else:
|
||||||
attributes = object_type.attributes
|
attributes = object_type.attributes
|
||||||
|
|
||||||
|
@ -582,26 +613,6 @@ class Stitcher:
|
||||||
break
|
break
|
||||||
old_typedtree_hash = typedtree_hash
|
old_typedtree_hash = typedtree_hash
|
||||||
|
|
||||||
# For every host class we embed, fill in the function slots
|
|
||||||
# with their corresponding closures.
|
|
||||||
for instance_type, constructor_type in self.embedding_map.iter_types():
|
|
||||||
# Do we have any direct reference to a constructor?
|
|
||||||
if len(self.value_map[constructor_type]) > 0:
|
|
||||||
# Yes, use it.
|
|
||||||
constructor, _constructor_loc = self.value_map[constructor_type][0]
|
|
||||||
else:
|
|
||||||
# No, extract one from a reference to an instance.
|
|
||||||
instance, _instance_loc = self.value_map[instance_type][0]
|
|
||||||
constructor = type(instance)
|
|
||||||
|
|
||||||
for attr in constructor_type.attributes:
|
|
||||||
if types.is_function(constructor_type.attributes[attr]):
|
|
||||||
synthesizer = self._synthesizer()
|
|
||||||
ast = synthesizer.assign_attribute(constructor, attr,
|
|
||||||
getattr(constructor, attr))
|
|
||||||
synthesizer.finalize()
|
|
||||||
self._inject(ast)
|
|
||||||
|
|
||||||
# After we have found all functions, synthesize a module to hold them.
|
# After we have found all functions, synthesize a module to hold them.
|
||||||
source_buffer = source.Buffer("", "<synthesized>")
|
source_buffer = source.Buffer("", "<synthesized>")
|
||||||
self.typedtree = asttyped.ModuleT(
|
self.typedtree = asttyped.ModuleT(
|
||||||
|
@ -619,11 +630,16 @@ class Stitcher:
|
||||||
quote_function=self._quote_function)
|
quote_function=self._quote_function)
|
||||||
|
|
||||||
def _quote_embedded_function(self, function, flags):
|
def _quote_embedded_function(self, function, flags):
|
||||||
if not hasattr(function, "artiq_embedded"):
|
if isinstance(function, SpecializedFunction):
|
||||||
raise ValueError("{} is not an embedded function".format(repr(function)))
|
host_function = function.host_function
|
||||||
|
else:
|
||||||
|
host_function = function
|
||||||
|
|
||||||
|
if not hasattr(host_function, "artiq_embedded"):
|
||||||
|
raise ValueError("{} is not an embedded function".format(repr(host_function)))
|
||||||
|
|
||||||
# Extract function source.
|
# Extract function source.
|
||||||
embedded_function = function.artiq_embedded.function
|
embedded_function = host_function.artiq_embedded.function
|
||||||
source_code = inspect.getsource(embedded_function)
|
source_code = inspect.getsource(embedded_function)
|
||||||
filename = embedded_function.__code__.co_filename
|
filename = embedded_function.__code__.co_filename
|
||||||
module_name = embedded_function.__globals__['__name__']
|
module_name = embedded_function.__globals__['__name__']
|
||||||
|
@ -652,7 +668,13 @@ class Stitcher:
|
||||||
function_node = parser.file_input().body[0]
|
function_node = parser.file_input().body[0]
|
||||||
|
|
||||||
# Mangle the name, since we put everything into a single module.
|
# Mangle the name, since we put everything into a single module.
|
||||||
function_node.name = "{}.{}".format(module_name, function.__qualname__)
|
full_function_name = "{}.{}".format(module_name, host_function.__qualname__)
|
||||||
|
if isinstance(function, SpecializedFunction):
|
||||||
|
instance_type = function.instance_type
|
||||||
|
function_node.name = "_Z{}{}I{}{}Ezz".format(len(full_function_name), full_function_name,
|
||||||
|
len(instance_type.name), instance_type.name)
|
||||||
|
else:
|
||||||
|
function_node.name = "_Z{}{}zz".format(len(full_function_name), full_function_name)
|
||||||
|
|
||||||
# Record the function in the function map so that LLVM IR generator
|
# Record the function in the function map so that LLVM IR generator
|
||||||
# can handle quoting it.
|
# can handle quoting it.
|
||||||
|
@ -808,64 +830,75 @@ class Stitcher:
|
||||||
return function_type
|
return function_type
|
||||||
|
|
||||||
def _quote_rpc(self, function, loc):
|
def _quote_rpc(self, function, loc):
|
||||||
|
if isinstance(function, SpecializedFunction):
|
||||||
|
host_function = function.host_function
|
||||||
|
else:
|
||||||
|
host_function = function
|
||||||
ret_type = builtins.TNone()
|
ret_type = builtins.TNone()
|
||||||
|
|
||||||
if isinstance(function, pytypes.BuiltinFunctionType):
|
if isinstance(host_function, pytypes.BuiltinFunctionType):
|
||||||
pass
|
pass
|
||||||
elif isinstance(function, pytypes.FunctionType) or isinstance(function, pytypes.MethodType):
|
elif (isinstance(host_function, pytypes.FunctionType) or \
|
||||||
if isinstance(function, pytypes.FunctionType):
|
isinstance(host_function, pytypes.MethodType)):
|
||||||
signature = inspect.signature(function)
|
if isinstance(host_function, pytypes.FunctionType):
|
||||||
|
signature = inspect.signature(host_function)
|
||||||
else:
|
else:
|
||||||
# inspect bug?
|
# inspect bug?
|
||||||
signature = inspect.signature(function.__func__)
|
signature = inspect.signature(host_function.__func__)
|
||||||
if signature.return_annotation is not inspect.Signature.empty:
|
if signature.return_annotation is not inspect.Signature.empty:
|
||||||
ret_type = self._extract_annot(function, signature.return_annotation,
|
ret_type = self._extract_annot(host_function, signature.return_annotation,
|
||||||
"return type", loc, is_syscall=False)
|
"return type", loc, is_syscall=False)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function))
|
function_type = types.TRPC(ret_type,
|
||||||
|
service=self.embedding_map.store_object(host_function))
|
||||||
self.functions[function] = function_type
|
self.functions[function] = function_type
|
||||||
return function_type
|
return function_type
|
||||||
|
|
||||||
def _quote_function(self, function, loc):
|
def _quote_function(self, function, loc):
|
||||||
|
if isinstance(function, SpecializedFunction):
|
||||||
|
host_function = function.host_function
|
||||||
|
else:
|
||||||
|
host_function = function
|
||||||
|
|
||||||
if function in self.functions:
|
if function in self.functions:
|
||||||
pass
|
pass
|
||||||
elif not hasattr(function, "artiq_embedded"):
|
elif not hasattr(host_function, "artiq_embedded"):
|
||||||
self._quote_rpc(function, loc)
|
self._quote_rpc(function, loc)
|
||||||
elif function.artiq_embedded.function is not None:
|
elif host_function.artiq_embedded.function is not None:
|
||||||
if function.__name__ == "<lambda>":
|
if host_function.__name__ == "<lambda>":
|
||||||
note = diagnostic.Diagnostic("note",
|
note = diagnostic.Diagnostic("note",
|
||||||
"lambda created here", {},
|
"lambda created here", {},
|
||||||
self._function_loc(function.artiq_embedded.function))
|
self._function_loc(host_function.artiq_embedded.function))
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
"lambdas cannot be used as kernel functions", {},
|
"lambdas cannot be used as kernel functions", {},
|
||||||
loc,
|
loc,
|
||||||
notes=[note])
|
notes=[note])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
core_name = function.artiq_embedded.core_name
|
core_name = host_function.artiq_embedded.core_name
|
||||||
if core_name is not None and self.dmgr.get(core_name) != self.core:
|
if core_name is not None and self.dmgr.get(core_name) != self.core:
|
||||||
note = diagnostic.Diagnostic("note",
|
note = diagnostic.Diagnostic("note",
|
||||||
"called from this function", {},
|
"called from this function", {},
|
||||||
loc)
|
loc)
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
"this function runs on a different core device '{name}'",
|
"this function runs on a different core device '{name}'",
|
||||||
{"name": function.artiq_embedded.core_name},
|
{"name": host_function.artiq_embedded.core_name},
|
||||||
self._function_loc(function.artiq_embedded.function),
|
self._function_loc(host_function.artiq_embedded.function),
|
||||||
notes=[note])
|
notes=[note])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
self._quote_embedded_function(function,
|
self._quote_embedded_function(function,
|
||||||
flags=function.artiq_embedded.flags)
|
flags=host_function.artiq_embedded.flags)
|
||||||
elif function.artiq_embedded.syscall is not None:
|
elif host_function.artiq_embedded.syscall is not None:
|
||||||
# Insert a storage-less global whose type instructs the compiler
|
# Insert a storage-less global whose type instructs the compiler
|
||||||
# to perform a system call instead of a regular call.
|
# to perform a system call instead of a regular call.
|
||||||
self._quote_syscall(function, loc)
|
self._quote_syscall(function, loc)
|
||||||
elif function.artiq_embedded.forbidden is not None:
|
elif host_function.artiq_embedded.forbidden is not None:
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
"this function cannot be called as an RPC", {},
|
"this function cannot be called as an RPC", {},
|
||||||
self._function_loc(function),
|
self._function_loc(host_function),
|
||||||
notes=self._call_site_note(loc, is_syscall=True))
|
notes=self._call_site_note(loc, is_syscall=True))
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -891,20 +891,27 @@ class Inferencer(algorithm.Visitor):
|
||||||
typ_optargs = typ.optargs
|
typ_optargs = typ.optargs
|
||||||
typ_ret = typ.ret
|
typ_ret = typ.ret
|
||||||
else:
|
else:
|
||||||
typ = types.get_method_function(typ)
|
typ_self = types.get_method_self(typ)
|
||||||
if types.is_var(typ):
|
typ_func = types.get_method_function(typ)
|
||||||
|
if types.is_var(typ_func):
|
||||||
return # not enough info yet
|
return # not enough info yet
|
||||||
elif types.is_rpc(typ):
|
elif types.is_rpc(typ_func):
|
||||||
self._unify(node.type, typ.ret,
|
self._unify(node.type, typ_func.ret,
|
||||||
node.loc, None)
|
node.loc, None)
|
||||||
return
|
return
|
||||||
elif typ.arity() == 0:
|
elif typ_func.arity() == 0:
|
||||||
return # error elsewhere
|
return # error elsewhere
|
||||||
|
|
||||||
typ_arity = typ.arity() - 1
|
method_args = list(typ_func.args.items())
|
||||||
typ_args = OrderedDict(list(typ.args.items())[1:])
|
|
||||||
typ_optargs = typ.optargs
|
self_arg_name, self_arg_type = method_args[0]
|
||||||
typ_ret = typ.ret
|
self._unify(self_arg_type, typ_self,
|
||||||
|
node.loc, None)
|
||||||
|
|
||||||
|
typ_arity = typ_func.arity() - 1
|
||||||
|
typ_args = OrderedDict(method_args[1:])
|
||||||
|
typ_optargs = typ_func.optargs
|
||||||
|
typ_ret = typ_func.ret
|
||||||
|
|
||||||
passed_args = dict()
|
passed_args = dict()
|
||||||
|
|
||||||
|
|
|
@ -378,11 +378,19 @@ class LLVMIRGenerator:
|
||||||
llfunty = self.llty_of_type(typ, bare=True)
|
llfunty = self.llty_of_type(typ, bare=True)
|
||||||
llfun = ll.Function(self.llmodule, llfunty, name)
|
llfun = ll.Function(self.llmodule, llfunty, name)
|
||||||
|
|
||||||
llretty = self.llty_of_type(typ.ret, for_return=True)
|
llretty = self.llty_of_type(typ.find().ret, for_return=True)
|
||||||
if self.needs_sret(llretty):
|
if self.needs_sret(llretty):
|
||||||
llfun.args[0].add_attribute('sret')
|
llfun.args[0].add_attribute('sret')
|
||||||
return llfun
|
return llfun
|
||||||
|
|
||||||
|
def get_function_with_undef_env(self, typ, name):
|
||||||
|
llfun = self.get_function(typ, name)
|
||||||
|
llclosure = ll.Constant(self.llty_of_type(typ), [
|
||||||
|
ll.Constant(llptr, ll.Undefined),
|
||||||
|
llfun
|
||||||
|
])
|
||||||
|
return llclosure
|
||||||
|
|
||||||
def map(self, value):
|
def map(self, value):
|
||||||
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
|
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
|
||||||
return self.llmap[value]
|
return self.llmap[value]
|
||||||
|
@ -408,19 +416,10 @@ class LLVMIRGenerator:
|
||||||
def emit_attribute_writeback(self):
|
def emit_attribute_writeback(self):
|
||||||
llobjects = defaultdict(lambda: [])
|
llobjects = defaultdict(lambda: [])
|
||||||
|
|
||||||
for obj_id in self.embedding_map.iter_objects():
|
for obj_id, obj_ref, obj_typ in self.embedding_map.iter_objects():
|
||||||
obj_ref = self.embedding_map.retrieve_object(obj_id)
|
|
||||||
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
|
|
||||||
pytypes.BuiltinFunctionType)):
|
|
||||||
continue
|
|
||||||
elif isinstance(obj_ref, type):
|
|
||||||
_, typ = self.embedding_map.retrieve_type(obj_ref)
|
|
||||||
else:
|
|
||||||
typ, _ = self.embedding_map.retrieve_type(type(obj_ref))
|
|
||||||
|
|
||||||
llobject = self.llmodule.get_global("O.{}".format(obj_id))
|
llobject = self.llmodule.get_global("O.{}".format(obj_id))
|
||||||
if llobject is not None:
|
if llobject is not None:
|
||||||
llobjects[typ].append(llobject.bitcast(llptr))
|
llobjects[obj_typ].append(llobject.bitcast(llptr))
|
||||||
|
|
||||||
llrpcattrty = self.llcontext.get_identified_type("A")
|
llrpcattrty = self.llcontext.get_identified_type("A")
|
||||||
llrpcattrty.elements = [lli32, llptr, llptr]
|
llrpcattrty.elements = [lli32, llptr, llptr]
|
||||||
|
@ -695,8 +694,8 @@ class LLVMIRGenerator:
|
||||||
llglobal = self.llmodule.get_global(name)
|
llglobal = self.llmodule.get_global(name)
|
||||||
else:
|
else:
|
||||||
llglobal = ll.GlobalVariable(self.llmodule, llty, name)
|
llglobal = ll.GlobalVariable(self.llmodule, llty, name)
|
||||||
if llvalue is not None:
|
|
||||||
llglobal.linkage = "private"
|
llglobal.linkage = "private"
|
||||||
|
if llvalue is not None:
|
||||||
llglobal.initializer = llvalue
|
llglobal.initializer = llvalue
|
||||||
return llglobal
|
return llglobal
|
||||||
|
|
||||||
|
@ -705,7 +704,7 @@ class LLVMIRGenerator:
|
||||||
llty = self.llty_of_type(typ).pointee
|
llty = self.llty_of_type(typ).pointee
|
||||||
return self.get_or_define_global("C.{}".format(typ.name), llty)
|
return self.get_or_define_global("C.{}".format(typ.name), llty)
|
||||||
|
|
||||||
def get_global_closure(self, typ, attr):
|
def get_global_closure_ptr(self, typ, attr):
|
||||||
closure_type = typ.attributes[attr]
|
closure_type = typ.attributes[attr]
|
||||||
assert types.is_constructor(typ)
|
assert types.is_constructor(typ)
|
||||||
assert types.is_function(closure_type) or types.is_rpc(closure_type)
|
assert types.is_function(closure_type) or types.is_rpc(closure_type)
|
||||||
|
@ -713,7 +712,13 @@ class LLVMIRGenerator:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
llty = self.llty_of_type(typ.attributes[attr])
|
llty = self.llty_of_type(typ.attributes[attr])
|
||||||
llclosureptr = self.get_or_define_global("F.{}.{}".format(typ.name, attr), llty)
|
return self.get_or_define_global("F.{}.{}".format(typ.name, attr), llty)
|
||||||
|
|
||||||
|
def get_global_closure(self, typ, attr):
|
||||||
|
llclosureptr = self.get_global_closure_ptr(typ, attr)
|
||||||
|
if llclosureptr is None:
|
||||||
|
return None
|
||||||
|
|
||||||
# LLVM's GlobalOpt pass only considers for SROA the globals that
|
# LLVM's GlobalOpt pass only considers for SROA the globals that
|
||||||
# are used only by GEPs, so we have to do this stupid hack.
|
# are used only by GEPs, so we have to do this stupid hack.
|
||||||
llenvptr = self.llbuilder.gep(llclosureptr, [self.llindex(0), self.llindex(0)])
|
llenvptr = self.llbuilder.gep(llclosureptr, [self.llindex(0), self.llindex(0)])
|
||||||
|
@ -721,12 +726,12 @@ class LLVMIRGenerator:
|
||||||
return [llenvptr, llfunptr]
|
return [llenvptr, llfunptr]
|
||||||
|
|
||||||
def load_closure(self, typ, attr):
|
def load_closure(self, typ, attr):
|
||||||
llclosureptrs = self.get_global_closure(typ, attr)
|
llclosureparts = self.get_global_closure(typ, attr)
|
||||||
if llclosureptrs is None:
|
if llclosureparts is None:
|
||||||
return ll.Constant(llunit, [])
|
return ll.Constant(llunit, [])
|
||||||
|
|
||||||
# See above.
|
# See above.
|
||||||
llenvptr, llfunptr = llclosureptrs
|
llenvptr, llfunptr = llclosureparts
|
||||||
llenv = self.llbuilder.load(llenvptr)
|
llenv = self.llbuilder.load(llenvptr)
|
||||||
llfun = self.llbuilder.load(llfunptr)
|
llfun = self.llbuilder.load(llfunptr)
|
||||||
llclosure = ll.Constant(ll.LiteralStructType([llenv.type, llfun.type]), ll.Undefined)
|
llclosure = ll.Constant(ll.LiteralStructType([llenv.type, llfun.type]), ll.Undefined)
|
||||||
|
@ -735,10 +740,10 @@ class LLVMIRGenerator:
|
||||||
return llclosure
|
return llclosure
|
||||||
|
|
||||||
def store_closure(self, llclosure, typ, attr):
|
def store_closure(self, llclosure, typ, attr):
|
||||||
llclosureptrs = self.get_global_closure(typ, attr)
|
llclosureparts = self.get_global_closure(typ, attr)
|
||||||
assert llclosureptrs is not None
|
assert llclosureparts is not None
|
||||||
|
|
||||||
llenvptr, llfunptr = llclosureptrs
|
llenvptr, llfunptr = llclosureparts
|
||||||
llenv = self.llbuilder.extract_value(llclosure, 0)
|
llenv = self.llbuilder.extract_value(llclosure, 0)
|
||||||
llfun = self.llbuilder.extract_value(llclosure, 1)
|
llfun = self.llbuilder.extract_value(llclosure, 1)
|
||||||
self.llbuilder.store(llenv, llenvptr)
|
self.llbuilder.store(llenv, llenvptr)
|
||||||
|
@ -1343,6 +1348,12 @@ class LLVMIRGenerator:
|
||||||
|
|
||||||
llty = self.llty_of_type(typ)
|
llty = self.llty_of_type(typ)
|
||||||
if types.is_constructor(typ) or types.is_instance(typ):
|
if types.is_constructor(typ) or types.is_instance(typ):
|
||||||
|
if types.is_instance(typ):
|
||||||
|
# Make sure the class functions are quoted, as this has the side effect of
|
||||||
|
# initializing the global closures.
|
||||||
|
self._quote(type(value), typ.constructor,
|
||||||
|
lambda: path() + ['__class__'])
|
||||||
|
|
||||||
llglobal = None
|
llglobal = None
|
||||||
llfields = []
|
llfields = []
|
||||||
for attr in typ.attributes:
|
for attr in typ.attributes:
|
||||||
|
@ -1359,8 +1370,18 @@ class LLVMIRGenerator:
|
||||||
|
|
||||||
self.llobject_map[value_id] = llglobal
|
self.llobject_map[value_id] = llglobal
|
||||||
else:
|
else:
|
||||||
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
|
attrvalue = getattr(value, attr)
|
||||||
lambda: path() + [attr]))
|
is_class_function = (types.is_constructor(typ) and
|
||||||
|
types.is_function(typ.attributes[attr]) and
|
||||||
|
not types.is_c_function(typ.attributes[attr]))
|
||||||
|
if is_class_function:
|
||||||
|
attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue)
|
||||||
|
llattrvalue = self._quote(attrvalue, typ.attributes[attr],
|
||||||
|
lambda: path() + [attr])
|
||||||
|
llfields.append(llattrvalue)
|
||||||
|
if is_class_function:
|
||||||
|
llclosureptr = self.get_global_closure_ptr(typ, attr)
|
||||||
|
llclosureptr.initializer = llattrvalue
|
||||||
|
|
||||||
llglobal.initializer = ll.Constant(llty.pointee, llfields)
|
llglobal.initializer = ll.Constant(llty.pointee, llfields)
|
||||||
llglobal.linkage = "private"
|
llglobal.linkage = "private"
|
||||||
|
@ -1400,12 +1421,8 @@ class LLVMIRGenerator:
|
||||||
# RPC and C functions have no runtime representation.
|
# RPC and C functions have no runtime representation.
|
||||||
return ll.Constant(llty, ll.Undefined)
|
return ll.Constant(llty, ll.Undefined)
|
||||||
elif types.is_function(typ):
|
elif types.is_function(typ):
|
||||||
llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value))
|
return self.get_function_with_undef_env(typ.find(),
|
||||||
llclosure = ll.Constant(self.llty_of_type(typ), [
|
self.embedding_map.retrieve_function(value))
|
||||||
ll.Constant(llptr, ll.Undefined),
|
|
||||||
llfun
|
|
||||||
])
|
|
||||||
return llclosure
|
|
||||||
elif types.is_method(typ):
|
elif types.is_method(typ):
|
||||||
llclosure = self._quote(value.__func__, types.get_method_function(typ),
|
llclosure = self._quote(value.__func__, types.get_method_function(typ),
|
||||||
lambda: path() + ['__func__'])
|
lambda: path() + ['__func__'])
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.embedding %s
|
||||||
|
|
||||||
|
from artiq.language.core import *
|
||||||
|
from artiq.language.types import *
|
||||||
|
|
||||||
|
class a:
|
||||||
|
@kernel
|
||||||
|
def f(self):
|
||||||
|
print(self.x)
|
||||||
|
return None
|
||||||
|
|
||||||
|
class b(a):
|
||||||
|
x = 1
|
||||||
|
class c(a):
|
||||||
|
x = 2
|
||||||
|
|
||||||
|
bi = b()
|
||||||
|
ci = c()
|
||||||
|
@kernel
|
||||||
|
def entrypoint():
|
||||||
|
bi.f()
|
||||||
|
ci.f()
|
Loading…
Reference in New Issue