diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 6727cf5d3..491a2052e 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -22,6 +22,26 @@ from .transforms.asttyped_rewriter import LocalExtractor def coredevice_print(x): print(x) +class SpecializedFunction: + def __init__(self, instance_type, host_function): + self.instance_type = instance_type + self.host_function = host_function + + def __eq__(self, other): + if isinstance(other, tuple): + return (self.instance_type == other[0] or + self.host_function == other[1]) + else: + return (self.instance_type == other.instance_type or + self.host_function == other.host_function) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.instance_type, self.host_function)) + + class EmbeddingMap: def __init__(self): self.object_current_key = 0 @@ -31,17 +51,14 @@ class EmbeddingMap: self.function_map = {} # Types - def store_type(self, typ, instance_type, constructor_type): - self.type_map[typ] = (instance_type, constructor_type) + def store_type(self, host_type, instance_type, constructor_type): + self.type_map[host_type] = (instance_type, constructor_type) - def retrieve_type(self, typ): - return self.type_map[typ] + def retrieve_type(self, host_type): + return self.type_map[host_type] - def has_type(self, typ): - return typ in self.type_map - - def iter_types(self): - return self.type_map.values() + def has_type(self, host_type): + return host_type in self.type_map # Functions def store_function(self, function, ir_function_name): @@ -50,6 +67,9 @@ class EmbeddingMap: def retrieve_function(self, function): return self.function_map[function] + def specialize_function(self, instance_type, host_function): + return SpecializedFunction(instance_type, host_function) + # Objects def store_object(self, obj_ref): obj_id = id(obj_ref) @@ -65,12 +85,22 @@ class EmbeddingMap: return self.object_forward_map[obj_key] def iter_objects(self): - return self.object_forward_map.keys() + for obj_id in self.object_forward_map.keys(): + obj_ref = self.object_forward_map[obj_id] + if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, + pytypes.BuiltinFunctionType, SpecializedFunction)): + continue + elif isinstance(obj_ref, type): + _, obj_typ = self.type_map[obj_ref] + else: + obj_typ, _ = self.type_map[type(obj_ref)] + yield obj_id, obj_ref, obj_typ def has_rpc(self): return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), self.object_forward_map.values())) + class ASTSynthesizer: def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None): self.source = "" @@ -128,7 +158,8 @@ class ASTSynthesizer: begin_loc=begin_loc, end_loc=end_loc, loc=begin_loc.join(end_loc)) elif inspect.isfunction(value) or inspect.ismethod(value) or \ - isinstance(value, pytypes.BuiltinFunctionType): + isinstance(value, pytypes.BuiltinFunctionType) or \ + isinstance(value, SpecializedFunction): if inspect.ismethod(value): quoted_self = self.quote(value.__self__) function_type = self.quote_function(value.__func__, self.expanded_from) @@ -139,7 +170,7 @@ class ASTSynthesizer: loc = quoted_self.loc.join(name_loc) return asttyped.QuoteT(value=value, type=method_type, self_loc=quoted_self.loc, loc=loc) - else: + else: # function function_type = self.quote_function(value, self.expanded_from) quote_loc = self._add('`') @@ -417,7 +448,7 @@ class StitchingInferencer(Inferencer): # def f(self): pass # we want f to be defined on the class, not on the instance. attributes = object_type.constructor.attributes - attr_value = attr_value.__func__ + attr_value = SpecializedFunction(object_type, attr_value.__func__) else: attributes = object_type.attributes @@ -582,26 +613,6 @@ class Stitcher: break old_typedtree_hash = typedtree_hash - # For every host class we embed, fill in the function slots - # with their corresponding closures. - for instance_type, constructor_type in self.embedding_map.iter_types(): - # Do we have any direct reference to a constructor? - if len(self.value_map[constructor_type]) > 0: - # Yes, use it. - constructor, _constructor_loc = self.value_map[constructor_type][0] - else: - # No, extract one from a reference to an instance. - instance, _instance_loc = self.value_map[instance_type][0] - constructor = type(instance) - - for attr in constructor_type.attributes: - if types.is_function(constructor_type.attributes[attr]): - synthesizer = self._synthesizer() - ast = synthesizer.assign_attribute(constructor, attr, - getattr(constructor, attr)) - synthesizer.finalize() - self._inject(ast) - # After we have found all functions, synthesize a module to hold them. source_buffer = source.Buffer("", "") self.typedtree = asttyped.ModuleT( @@ -619,11 +630,16 @@ class Stitcher: quote_function=self._quote_function) def _quote_embedded_function(self, function, flags): - if not hasattr(function, "artiq_embedded"): - raise ValueError("{} is not an embedded function".format(repr(function))) + if isinstance(function, SpecializedFunction): + host_function = function.host_function + else: + host_function = function + + if not hasattr(host_function, "artiq_embedded"): + raise ValueError("{} is not an embedded function".format(repr(host_function))) # Extract function source. - embedded_function = function.artiq_embedded.function + embedded_function = host_function.artiq_embedded.function source_code = inspect.getsource(embedded_function) filename = embedded_function.__code__.co_filename module_name = embedded_function.__globals__['__name__'] @@ -652,7 +668,13 @@ class Stitcher: function_node = parser.file_input().body[0] # Mangle the name, since we put everything into a single module. - function_node.name = "{}.{}".format(module_name, function.__qualname__) + full_function_name = "{}.{}".format(module_name, host_function.__qualname__) + if isinstance(function, SpecializedFunction): + instance_type = function.instance_type + function_node.name = "_Z{}{}I{}{}Ezz".format(len(full_function_name), full_function_name, + len(instance_type.name), instance_type.name) + else: + function_node.name = "_Z{}{}zz".format(len(full_function_name), full_function_name) # Record the function in the function map so that LLVM IR generator # can handle quoting it. @@ -808,64 +830,75 @@ class Stitcher: return function_type def _quote_rpc(self, function, loc): + if isinstance(function, SpecializedFunction): + host_function = function.host_function + else: + host_function = function ret_type = builtins.TNone() - if isinstance(function, pytypes.BuiltinFunctionType): + if isinstance(host_function, pytypes.BuiltinFunctionType): pass - elif isinstance(function, pytypes.FunctionType) or isinstance(function, pytypes.MethodType): - if isinstance(function, pytypes.FunctionType): - signature = inspect.signature(function) + elif (isinstance(host_function, pytypes.FunctionType) or \ + isinstance(host_function, pytypes.MethodType)): + if isinstance(host_function, pytypes.FunctionType): + signature = inspect.signature(host_function) else: # inspect bug? - signature = inspect.signature(function.__func__) + signature = inspect.signature(host_function.__func__) if signature.return_annotation is not inspect.Signature.empty: - ret_type = self._extract_annot(function, signature.return_annotation, + ret_type = self._extract_annot(host_function, signature.return_annotation, "return type", loc, is_syscall=False) else: assert False - function_type = types.TRPC(ret_type, service=self.embedding_map.store_object(function)) + function_type = types.TRPC(ret_type, + service=self.embedding_map.store_object(host_function)) self.functions[function] = function_type return function_type def _quote_function(self, function, loc): + if isinstance(function, SpecializedFunction): + host_function = function.host_function + else: + host_function = function + if function in self.functions: pass - elif not hasattr(function, "artiq_embedded"): + elif not hasattr(host_function, "artiq_embedded"): self._quote_rpc(function, loc) - elif function.artiq_embedded.function is not None: - if function.__name__ == "": + elif host_function.artiq_embedded.function is not None: + if host_function.__name__ == "": note = diagnostic.Diagnostic("note", "lambda created here", {}, - self._function_loc(function.artiq_embedded.function)) + self._function_loc(host_function.artiq_embedded.function)) diag = diagnostic.Diagnostic("fatal", "lambdas cannot be used as kernel functions", {}, loc, notes=[note]) self.engine.process(diag) - core_name = function.artiq_embedded.core_name + core_name = host_function.artiq_embedded.core_name if core_name is not None and self.dmgr.get(core_name) != self.core: note = diagnostic.Diagnostic("note", "called from this function", {}, loc) diag = diagnostic.Diagnostic("fatal", "this function runs on a different core device '{name}'", - {"name": function.artiq_embedded.core_name}, - self._function_loc(function.artiq_embedded.function), + {"name": host_function.artiq_embedded.core_name}, + self._function_loc(host_function.artiq_embedded.function), notes=[note]) self.engine.process(diag) self._quote_embedded_function(function, - flags=function.artiq_embedded.flags) - elif function.artiq_embedded.syscall is not None: + flags=host_function.artiq_embedded.flags) + elif host_function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. self._quote_syscall(function, loc) - elif function.artiq_embedded.forbidden is not None: + elif host_function.artiq_embedded.forbidden is not None: diag = diagnostic.Diagnostic("fatal", "this function cannot be called as an RPC", {}, - self._function_loc(function), + self._function_loc(host_function), notes=self._call_site_note(loc, is_syscall=True)) self.engine.process(diag) else: diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 50fb7a298..d0216cd50 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -891,20 +891,27 @@ class Inferencer(algorithm.Visitor): typ_optargs = typ.optargs typ_ret = typ.ret else: - typ = types.get_method_function(typ) - if types.is_var(typ): + typ_self = types.get_method_self(typ) + typ_func = types.get_method_function(typ) + if types.is_var(typ_func): return # not enough info yet - elif types.is_rpc(typ): - self._unify(node.type, typ.ret, + elif types.is_rpc(typ_func): + self._unify(node.type, typ_func.ret, node.loc, None) return - elif typ.arity() == 0: + elif typ_func.arity() == 0: return # error elsewhere - typ_arity = typ.arity() - 1 - typ_args = OrderedDict(list(typ.args.items())[1:]) - typ_optargs = typ.optargs - typ_ret = typ.ret + method_args = list(typ_func.args.items()) + + self_arg_name, self_arg_type = method_args[0] + self._unify(self_arg_type, typ_self, + node.loc, None) + + typ_arity = typ_func.arity() - 1 + typ_args = OrderedDict(method_args[1:]) + typ_optargs = typ_func.optargs + typ_ret = typ_func.ret passed_args = dict() diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index a92378c55..9f1eb25d9 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -378,11 +378,19 @@ class LLVMIRGenerator: llfunty = self.llty_of_type(typ, bare=True) llfun = ll.Function(self.llmodule, llfunty, name) - llretty = self.llty_of_type(typ.ret, for_return=True) + llretty = self.llty_of_type(typ.find().ret, for_return=True) if self.needs_sret(llretty): llfun.args[0].add_attribute('sret') return llfun + def get_function_with_undef_env(self, typ, name): + llfun = self.get_function(typ, name) + llclosure = ll.Constant(self.llty_of_type(typ), [ + ll.Constant(llptr, ll.Undefined), + llfun + ]) + return llclosure + def map(self, value): if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)): return self.llmap[value] @@ -408,19 +416,10 @@ class LLVMIRGenerator: def emit_attribute_writeback(self): llobjects = defaultdict(lambda: []) - for obj_id in self.embedding_map.iter_objects(): - obj_ref = self.embedding_map.retrieve_object(obj_id) - if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, - pytypes.BuiltinFunctionType)): - continue - elif isinstance(obj_ref, type): - _, typ = self.embedding_map.retrieve_type(obj_ref) - else: - typ, _ = self.embedding_map.retrieve_type(type(obj_ref)) - + for obj_id, obj_ref, obj_typ in self.embedding_map.iter_objects(): llobject = self.llmodule.get_global("O.{}".format(obj_id)) if llobject is not None: - llobjects[typ].append(llobject.bitcast(llptr)) + llobjects[obj_typ].append(llobject.bitcast(llptr)) llrpcattrty = self.llcontext.get_identified_type("A") llrpcattrty.elements = [lli32, llptr, llptr] @@ -695,8 +694,8 @@ class LLVMIRGenerator: llglobal = self.llmodule.get_global(name) else: llglobal = ll.GlobalVariable(self.llmodule, llty, name) + llglobal.linkage = "private" if llvalue is not None: - llglobal.linkage = "private" llglobal.initializer = llvalue return llglobal @@ -705,7 +704,7 @@ class LLVMIRGenerator: llty = self.llty_of_type(typ).pointee return self.get_or_define_global("C.{}".format(typ.name), llty) - def get_global_closure(self, typ, attr): + def get_global_closure_ptr(self, typ, attr): closure_type = typ.attributes[attr] assert types.is_constructor(typ) assert types.is_function(closure_type) or types.is_rpc(closure_type) @@ -713,7 +712,13 @@ class LLVMIRGenerator: return None llty = self.llty_of_type(typ.attributes[attr]) - llclosureptr = self.get_or_define_global("F.{}.{}".format(typ.name, attr), llty) + return self.get_or_define_global("F.{}.{}".format(typ.name, attr), llty) + + def get_global_closure(self, typ, attr): + llclosureptr = self.get_global_closure_ptr(typ, attr) + if llclosureptr is None: + return None + # LLVM's GlobalOpt pass only considers for SROA the globals that # are used only by GEPs, so we have to do this stupid hack. llenvptr = self.llbuilder.gep(llclosureptr, [self.llindex(0), self.llindex(0)]) @@ -721,12 +726,12 @@ class LLVMIRGenerator: return [llenvptr, llfunptr] def load_closure(self, typ, attr): - llclosureptrs = self.get_global_closure(typ, attr) - if llclosureptrs is None: + llclosureparts = self.get_global_closure(typ, attr) + if llclosureparts is None: return ll.Constant(llunit, []) # See above. - llenvptr, llfunptr = llclosureptrs + llenvptr, llfunptr = llclosureparts llenv = self.llbuilder.load(llenvptr) llfun = self.llbuilder.load(llfunptr) llclosure = ll.Constant(ll.LiteralStructType([llenv.type, llfun.type]), ll.Undefined) @@ -735,10 +740,10 @@ class LLVMIRGenerator: return llclosure def store_closure(self, llclosure, typ, attr): - llclosureptrs = self.get_global_closure(typ, attr) - assert llclosureptrs is not None + llclosureparts = self.get_global_closure(typ, attr) + assert llclosureparts is not None - llenvptr, llfunptr = llclosureptrs + llenvptr, llfunptr = llclosureparts llenv = self.llbuilder.extract_value(llclosure, 0) llfun = self.llbuilder.extract_value(llclosure, 1) self.llbuilder.store(llenv, llenvptr) @@ -1343,6 +1348,12 @@ class LLVMIRGenerator: llty = self.llty_of_type(typ) if types.is_constructor(typ) or types.is_instance(typ): + if types.is_instance(typ): + # Make sure the class functions are quoted, as this has the side effect of + # initializing the global closures. + self._quote(type(value), typ.constructor, + lambda: path() + ['__class__']) + llglobal = None llfields = [] for attr in typ.attributes: @@ -1359,8 +1370,18 @@ class LLVMIRGenerator: self.llobject_map[value_id] = llglobal else: - llfields.append(self._quote(getattr(value, attr), typ.attributes[attr], - lambda: path() + [attr])) + attrvalue = getattr(value, attr) + is_class_function = (types.is_constructor(typ) and + types.is_function(typ.attributes[attr]) and + not types.is_c_function(typ.attributes[attr])) + if is_class_function: + attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue) + llattrvalue = self._quote(attrvalue, typ.attributes[attr], + lambda: path() + [attr]) + llfields.append(llattrvalue) + if is_class_function: + llclosureptr = self.get_global_closure_ptr(typ, attr) + llclosureptr.initializer = llattrvalue llglobal.initializer = ll.Constant(llty.pointee, llfields) llglobal.linkage = "private" @@ -1400,12 +1421,8 @@ class LLVMIRGenerator: # RPC and C functions have no runtime representation. return ll.Constant(llty, ll.Undefined) elif types.is_function(typ): - llfun = self.get_function(typ.find(), self.embedding_map.retrieve_function(value)) - llclosure = ll.Constant(self.llty_of_type(typ), [ - ll.Constant(llptr, ll.Undefined), - llfun - ]) - return llclosure + return self.get_function_with_undef_env(typ.find(), + self.embedding_map.retrieve_function(value)) elif types.is_method(typ): llclosure = self._quote(value.__func__, types.get_method_function(typ), lambda: path() + ['__func__']) diff --git a/artiq/test/lit/embedding/inheritance.py b/artiq/test/lit/embedding/inheritance.py new file mode 100644 index 000000000..0863d9163 --- /dev/null +++ b/artiq/test/lit/embedding/inheritance.py @@ -0,0 +1,22 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s + +from artiq.language.core import * +from artiq.language.types import * + +class a: + @kernel + def f(self): + print(self.x) + return None + +class b(a): + x = 1 +class c(a): + x = 2 + +bi = b() +ci = c() +@kernel +def entrypoint(): + bi.f() + ci.f()