From c21387dc0926553174461fec76ea957eefbb8779 Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 27 Aug 2015 19:46:50 -0500 Subject: [PATCH] compiler.embedding: support calling methods marked as @kernel. --- artiq/compiler/embedding.py | 160 +++++++++++++----- artiq/compiler/transforms/inferencer.py | 6 +- .../compiler/transforms/llvm_ir_generator.py | 10 +- 3 files changed, 125 insertions(+), 51 deletions(-) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 33cf2b9d0..c93a49f7b 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -34,10 +34,11 @@ class ObjectMap: return self.forward_map[obj_key] class ASTSynthesizer: - def __init__(self, type_map, value_map, expanded_from=None): + def __init__(self, type_map, value_map, quote_function=None, expanded_from=None): self.source = "" self.source_buffer = source.Buffer(self.source, "") self.type_map, self.value_map = type_map, value_map + self.quote_function = quote_function self.expanded_from = expanded_from def finalize(self): @@ -82,6 +83,10 @@ class ASTSynthesizer: return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), begin_loc=begin_loc, end_loc=end_loc, loc=begin_loc.join(end_loc)) + elif inspect.isfunction(value) or inspect.ismethod(value): + function_name, function_type = self.quote_function(value, self.expanded_from) + return asttyped.NameT(id=function_name, ctx=None, type=function_type, + loc=self._add(repr(value))) else: quote_loc = self._add('`') repr_loc = self._add(repr(value)) @@ -155,6 +160,36 @@ class ASTSynthesizer: begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, loc=name_loc.join(end_loc)) + def assign_local(self, var_name, value): + name_loc = self._add(var_name) + _ = self._add(" ") + equals_loc = self._add("=") + _ = self._add(" ") + value_node = self.quote(value) + + var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type, + loc=name_loc) + + return ast.Assign(targets=[var_node], value=value_node, + op_locs=[equals_loc], loc=name_loc.join(value_node.loc)) + + def assign_attribute(self, obj, attr_name, value): + obj_node = self.quote(obj) + dot_loc = self._add(".") + name_loc = self._add(attr_name) + _ = self._add(" ") + equals_loc = self._add("=") + _ = self._add(" ") + value_node = self.quote(value) + + attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None, + type=value_node.type, + dot_loc=dot_loc, attr_loc=name_loc, + loc=obj_node.loc.join(name_loc)) + + return ast.Assign(targets=[attr_node], value=value_node, + op_locs=[equals_loc], loc=name_loc.join(value_node.loc)) + class StitchingASTTypedRewriter(ASTTypedRewriter): def __init__(self, engine, prelude, globals, host_environment, quote): super().__init__(engine, prelude) @@ -221,7 +256,20 @@ class StitchingInferencer(Inferencer): # overhead (i.e. synthesizing a source buffer), but has the advantage # of having the host-to-ARTIQ mapping code in only one place and # also immediately getting proper diagnostics on type errors. - ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from) + attr_value = getattr(object_value, node.attr) + if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded') + and types.is_instance(object_type)): + # In cases like: + # class c: + # @kernel + # def f(self): pass + # we want f to be defined on the class, not on the instance. + attributes = object_type.constructor.attributes + attr_value = attr_value.__func__ + else: + attributes = object_type.attributes + + ast = self.quote(attr_value, None) def proxy_diagnostic(diag): note = diagnostic.Diagnostic("note", @@ -238,17 +286,17 @@ class StitchingInferencer(Inferencer): Inferencer(engine=proxy_engine).visit(ast) IntMonomorphizer(engine=proxy_engine).visit(ast) - if node.attr not in object_type.attributes: + if node.attr not in attributes: # We just figured out what the type should be. Add it. - object_type.attributes[node.attr] = ast.type - elif object_type.attributes[node.attr] != ast.type: + attributes[node.attr] = ast.type + elif attributes[node.attr] != ast.type: # Does this conflict with an earlier guess? printer = types.TypePrinter() diag = diagnostic.Diagnostic("error", "host object has an attribute of type {typea}, which is" " different from previously inferred type {typeb}", {"typea": printer.name(ast.type), - "typeb": printer.name(object_type.attributes[node.attr])}, + "typeb": printer.name(attributes[node.attr])}, object_loc) self.engine.process(diag) @@ -261,11 +309,9 @@ class TypedtreeHasher(algorithm.Visitor): return self.visit(obj) elif isinstance(obj, types.Type): return hash(obj.find()) - elif isinstance(obj, list): - return tuple(obj) else: - assert obj is None or isinstance(obj, (bool, int, float, str)) - return obj + # We don't care; only types change during inference. + pass fields = node._fields if hasattr(node, '_types'): @@ -281,6 +327,7 @@ class Stitcher: self.name = "" self.typedtree = [] + self.inject_at = 0 self.prelude = prelude.globals() self.globals = {} @@ -290,6 +337,17 @@ class Stitcher: self.type_map = {} self.value_map = defaultdict(lambda: []) + def stitch_call(self, function, args, kwargs): + function_node = self._quote_embedded_function(function) + self.typedtree.append(function_node) + + # We synthesize source code for the initial call so that + # diagnostics would have something meaningful to display to the user. + synthesizer = self._synthesizer() + call_node = synthesizer.call(function_node, args, kwargs) + synthesizer.finalize() + self.typedtree.append(call_node) + def finalize(self): inferencer = StitchingInferencer(engine=self.engine, value_map=self.value_map, @@ -306,12 +364,50 @@ class Stitcher: break old_typedtree_hash = typedtree_hash + # For every host class we embed, add an appropriate constructor + # as a global. This is necessary for method lookup, which uses + # the getconstructor instruction. + for instance_type, constructor_type in list(self.type_map.values()): + # Do we have any direct reference to a constructor? + if len(self.value_map[constructor_type]) > 0: + # Yes, use it. + constructor, _constructor_loc = self.value_map[constructor_type][0] + else: + # No, extract one from a reference to an instance. + instance, _instance_loc = self.value_map[instance_type][0] + constructor = type(instance) + + self.globals[constructor_type.name] = constructor_type + + synthesizer = self._synthesizer() + ast = synthesizer.assign_local(constructor_type.name, constructor) + synthesizer.finalize() + self._inject(ast) + + for attr in constructor_type.attributes: + if types.is_function(constructor_type.attributes[attr]): + synthesizer = self._synthesizer() + ast = synthesizer.assign_attribute(constructor, attr, + getattr(constructor, attr)) + synthesizer.finalize() + self._inject(ast) + # After we have found all functions, synthesize a module to hold them. source_buffer = source.Buffer("", "") self.typedtree = asttyped.ModuleT( typing_env=self.globals, globals_in_scope=set(), body=self.typedtree, loc=source.Range(source_buffer, 0, 0)) + def _inject(self, node): + self.typedtree.insert(self.inject_at, node) + self.inject_at += 1 + + def _synthesizer(self, expanded_from=None): + return ASTSynthesizer(expanded_from=expanded_from, + type_map=self.type_map, + value_map=self.value_map, + quote_function=self._quote_function) + def _quote_embedded_function(self, function): if not hasattr(function, "artiq_embedded"): raise ValueError("{} is not an embedded function".format(repr(function))) @@ -414,10 +510,7 @@ class Stitcher: # This is tricky, because the default value might not have # a well-defined type in APython. # In this case, we bail out, but mention why we do it. - synthesizer = ASTSynthesizer(type_map=self.type_map, - value_map=self.value_map) - ast = synthesizer.quote(param.default) - synthesizer.finalize() + ast = self._quote(param.default, None) def proxy_diagnostic(diag): note = diagnostic.Diagnostic("note", @@ -499,11 +592,12 @@ class Stitcher: self.globals[function_name] = function_type self.functions[function] = function_name - return function_name + return function_name, function_type def _quote_function(self, function, loc): if function in self.functions: - return self.functions[function] + function_name = self.functions[function] + return function_name, self.globals[function_name] if hasattr(function, "artiq_embedded"): if function.artiq_embedded.function is not None: @@ -511,8 +605,8 @@ class Stitcher: # It doesn't really matter where we insert as long as it is before # the final call. function_node = self._quote_embedded_function(function) - self.typedtree.insert(0, function_node) - return function_node.name + self._inject(function_node) + return function_node.name, self.globals[function_node.name] elif function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. @@ -527,31 +621,7 @@ class Stitcher: syscall=None) def _quote(self, value, loc): - if inspect.isfunction(value) or inspect.ismethod(value): - # It's a function. We need to translate the function and insert - # a reference to it. - function_name = self._quote_function(value, loc) - return asttyped.NameT(id=function_name, ctx=None, - type=self.globals[function_name], - loc=loc) - - else: - # It's just a value. Quote it. - synthesizer = ASTSynthesizer(expanded_from=loc, - type_map=self.type_map, - value_map=self.value_map) - node = synthesizer.quote(value) - synthesizer.finalize() - return node - - def stitch_call(self, function, args, kwargs): - function_node = self._quote_embedded_function(function) - self.typedtree.append(function_node) - - # We synthesize source code for the initial call so that - # diagnostics would have something meaningful to display to the user. - synthesizer = ASTSynthesizer(type_map=self.type_map, - value_map=self.value_map) - call_node = synthesizer.call(function_node, args, kwargs) + synthesizer = self._synthesizer(loc) + node = synthesizer.quote(value) synthesizer.finalize() - self.typedtree.append(call_node) + return node diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 119ffdab5..aa726e98f 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -127,8 +127,10 @@ class Inferencer(algorithm.Visitor): when=" while inferring the type for self argument") attr_type = types.TMethod(object_type, attr_type) - self._unify(node.type, attr_type, - node.loc, None) + + if not types.is_var(attr_type): + self._unify(node.type, attr_type, + node.loc, None) else: if node.attr_loc.source_buffer == node.value.loc.source_buffer: highlights, notes = [node.value.loc], [] diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 5fe65363b..51218b294 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -266,7 +266,7 @@ class LLVMIRGenerator: elif types.is_constructor(typ): name = "class.{}".format(typ.name) else: - name = typ.name + name = "instance.{}".format(typ.name) llty = self.llcontext.get_identified_type(name) if llty.elements is None: @@ -991,7 +991,7 @@ class LLVMIRGenerator: llfields.append(self._quote(getattr(value, attr), typ.attributes[attr], lambda: path() + [attr])) - llvalue = ll.Constant.literal_struct(llfields) + llvalue = ll.Constant(llty.pointee, llfields) llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name) llconst.initializer = llvalue llconst.linkage = "private" @@ -1012,8 +1012,10 @@ class LLVMIRGenerator: elif builtins.is_str(typ): assert isinstance(value, (str, bytes)) return self.llstr_of_str(value) - elif types.is_rpc_function(typ): - return ll.Constant.literal_struct([]) + elif types.is_function(typ): + # RPC and C functions have no runtime representation; ARTIQ + # functions are initialized explicitly. + return ll.Constant(llty, ll.Undefined) else: print(typ) assert False