diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 58f112f98..e49a65f3a 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -107,11 +107,8 @@ class ASTSynthesizer: unquote_loc = self._add('`') loc = quote_loc.join(unquote_loc) - function_name, function_type = self.quote_function(value, self.expanded_from) - if function_name is None: - return asttyped.QuoteT(value=value, type=function_type, loc=loc) - else: - return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc) + function_type = self.quote_function(value, self.expanded_from) + return asttyped.QuoteT(value=value, type=function_type, loc=loc) else: quote_loc = self._add('`') repr_loc = self._add(repr(value)) @@ -155,7 +152,7 @@ class ASTSynthesizer: return asttyped.QuoteT(value=value, type=instance_type, loc=loc) - def call(self, function_node, args, kwargs, callback=None): + def call(self, callee, args, kwargs, callback=None): """ Construct an AST fragment calling a function specified by an AST node `function_node`, with given arguments. @@ -164,11 +161,11 @@ class ASTSynthesizer: callback_node = self.quote(callback) cb_begin_loc = self._add("(") + callee_node = self.quote(callee) arg_nodes = [] kwarg_nodes = [] kwarg_locs = [] - name_loc = self._add(function_node.name) begin_loc = self._add("(") for index, arg in enumerate(args): arg_nodes.append(self.quote(arg)) @@ -189,9 +186,7 @@ class ASTSynthesizer: cb_end_loc = self._add(")") node = asttyped.CallT( - func=asttyped.NameT(id=function_node.name, ctx=None, - type=function_node.signature_type, - loc=name_loc), + func=callee_node, args=arg_nodes, keywords=[ast.keyword(arg=kw, value=value, arg_loc=arg_loc, equals_loc=equals_loc, @@ -201,7 +196,7 @@ class ASTSynthesizer: starargs=None, kwargs=None, type=types.TVar(), iodelay=None, arg_exprs={}, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, - loc=name_loc.join(end_loc)) + loc=callee_node.loc.join(end_loc)) if callback is not None: node = asttyped.CallT( @@ -213,19 +208,6 @@ class ASTSynthesizer: return node - def assign_local(self, var_name, value): - name_loc = self._add(var_name) - _ = self._add(" ") - equals_loc = self._add("=") - _ = self._add(" ") - value_node = self.quote(value) - - var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type, - loc=name_loc) - - return ast.Assign(targets=[var_node], value=value_node, - op_locs=[equals_loc], loc=name_loc.join(value_node.loc)) - def assign_attribute(self, obj, attr_name, value): obj_node = self.quote(obj) dot_loc = self._add(".") @@ -465,18 +447,16 @@ class Stitcher: self.functions = {} + self.function_map = {} self.object_map = ObjectMap() self.type_map = {} self.value_map = defaultdict(lambda: []) def stitch_call(self, function, args, kwargs, callback=None): - function_node = self._quote_embedded_function(function) - self.typedtree.append(function_node) - # We synthesize source code for the initial call so that # diagnostics would have something meaningful to display to the user. synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function)) - call_node = synthesizer.call(function_node, args, kwargs, callback) + call_node = synthesizer.call(function, args, kwargs, callback) synthesizer.finalize() self.typedtree.append(call_node) @@ -496,9 +476,8 @@ class Stitcher: break old_typedtree_hash = typedtree_hash - # For every host class we embed, add an appropriate constructor - # as a global. This is necessary for method lookup, which uses - # the getconstructor instruction. + # For every host class we embed, fill in the function slots + # with their corresponding closures. for instance_type, constructor_type in list(self.type_map.values()): # Do we have any direct reference to a constructor? if len(self.value_map[constructor_type]) > 0: @@ -509,13 +488,6 @@ class Stitcher: instance, _instance_loc = self.value_map[instance_type][0] constructor = type(instance) - self.globals[constructor_type.name] = constructor_type - - synthesizer = self._synthesizer() - ast = synthesizer.assign_local(constructor_type.name, constructor) - synthesizer.finalize() - self._inject(ast) - for attr in constructor_type.attributes: if types.is_function(constructor_type.attributes[attr]): synthesizer = self._synthesizer() @@ -577,23 +549,28 @@ class Stitcher: # Mangle the name, since we put everything into a single module. function_node.name = "{}.{}".format(module_name, function.__qualname__) - # Normally, LocalExtractor would populate the typing environment - # of the module with the function name. However, since we run - # ASTTypedRewriter on the function node directly, we need to do it - # explicitly. - function_type = types.TVar() - self.globals[function_node.name] = function_type + # Record the function in the function map so that LLVM IR generator + # can handle quoting it. + self.function_map[function] = function_node.name - # Memoize the function before typing it to handle recursive + # Memoize the function type before typing it to handle recursive # invocations. - self.functions[function] = function_node.name, function_type + self.functions[function] = types.TVar() # Rewrite into typed form. asttyped_rewriter = StitchingASTTypedRewriter( engine=self.engine, prelude=self.prelude, globals=self.globals, host_environment=host_environment, quote=self._quote) - return asttyped_rewriter.visit(function_node) + function_node = asttyped_rewriter.visit(function_node) + + # Add it into our typedtree so that it gets inferenced and codegen'd. + self._inject(function_node) + + # Tie the typing knot. + self.functions[function].unify(function_node.signature_type) + + return function_node def _function_loc(self, function): filename = function.__code__.co_filename @@ -734,14 +711,12 @@ class Stitcher: function_type = types.TCFunction(arg_types, ret_type, name=syscall) - self.functions[function] = None, function_type + self.functions[function] = function_type - return None, function_type + return function_type def _quote_function(self, function, loc): - if function in self.functions: - result = self.functions[function] - else: + if function not in self.functions: if hasattr(function, "artiq_embedded"): if function.artiq_embedded.function is not None: if function.__name__ == "": @@ -766,17 +741,12 @@ class Stitcher: notes=[note]) self.engine.process(diag) - # Insert the typed AST for the new function and restart inference. - # It doesn't really matter where we insert as long as it is before - # the final call. - function_node = self._quote_embedded_function(function) - self._inject(function_node) - result = function_node.name, self.globals[function_node.name] + self._quote_embedded_function(function) elif function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. - result = self._quote_foreign_function(function, loc, - syscall=function.artiq_embedded.syscall) + self._quote_foreign_function(function, loc, + syscall=function.artiq_embedded.syscall) elif function.artiq_embedded.forbidden is not None: diag = diagnostic.Diagnostic("fatal", "this function cannot be called as an RPC", {}, @@ -788,12 +758,12 @@ class Stitcher: else: # Insert a storage-less global whose type instructs the compiler # to perform an RPC instead of a regular call. - result = self._quote_foreign_function(function, loc, syscall=None) + self._quote_foreign_function(function, loc, syscall=None) - function_name, function_type = result + function_type = self.functions[function] if types.is_rpc_function(function_type): function_type = types.instantiate(function_type) - return function_name, function_type + return function_type def _quote(self, value, loc): synthesizer = self._synthesizer(loc) diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index c5f80ec42..06c3a2615 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -19,6 +19,7 @@ class Source: else: self.engine = engine + self.function_map = {} self.object_map = None self.type_map = {} @@ -45,6 +46,7 @@ class Source: class Module: def __init__(self, src, ref_period=1e-6): self.engine = src.engine + self.function_map = src.function_map self.object_map = src.object_map self.type_map = src.type_map @@ -80,7 +82,7 @@ class Module: """Compile the module to LLVM IR for the specified target.""" llvm_ir_generator = transforms.LLVMIRGenerator( engine=self.engine, module_name=self.name, target=target, - object_map=self.object_map, type_map=self.type_map) + function_map=self.function_map, object_map=self.object_map, type_map=self.type_map) return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True) def entry_point(self): diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 90b8e7bd0..92f66f7ae 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -319,13 +319,14 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Closure(func, self.current_env)) - def visit_FunctionDefT(self, node, in_class=None): + def visit_FunctionDefT(self, node): func = self.visit_function(node, is_lambda=False, is_internal=len(self.name) > 0 or '.' in node.name) - if in_class is None: - self._set_local(node.name, func) + if self.current_class is None: + if node.name in self.current_env.type.params: + self._set_local(node.name, func) else: - self.append(ir.SetAttr(in_class, node.name, func)) + self.append(ir.SetAttr(self.current_class, node.name, func)) def visit_Return(self, node): if node.value is None: diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index 43455143f..6a58748ee 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -190,16 +190,16 @@ class ASTTypedRewriter(algorithm.Transformer): self.in_class = None def _try_find_name(self, name): - for typing_env in reversed(self.env_stack): - if name in typing_env: - return typing_env[name] - - def _find_name(self, name, loc): if self.in_class is not None: typ = self.in_class.constructor_type.attributes.get(name) if typ is not None: return typ + for typing_env in reversed(self.env_stack): + if name in typing_env: + return typing_env[name] + + def _find_name(self, name, loc): typ = self._try_find_name(name) if typ is not None: return typ @@ -229,9 +229,13 @@ class ASTTypedRewriter(algorithm.Transformer): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor.visit(node) + signature_type = self._try_find_name(node.name) + if signature_type is None: + signature_type = types.TVar() + node = asttyped.FunctionDefT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, - signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(), + signature_type=signature_type, return_type=types.TVar(), name=node.name, args=node.args, returns=node.returns, body=node.body, decorator_list=node.decorator_list, keyword_loc=node.keyword_loc, name_loc=node.name_loc, diff --git a/artiq/compiler/transforms/dead_code_eliminator.py b/artiq/compiler/transforms/dead_code_eliminator.py index eb6d1a191..6aae22dd4 100644 --- a/artiq/compiler/transforms/dead_code_eliminator.py +++ b/artiq/compiler/transforms/dead_code_eliminator.py @@ -33,7 +33,7 @@ class DeadCodeEliminator: # it also has to run after the interleaver, but interleaver # doesn't like to work with IR before DCE. if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce, - ir.Arith, ir.Compare, ir.Closure, ir.Select, ir.Quote)) \ + ir.Arith, ir.Compare, ir.Select, ir.Quote)) \ and not any(insn.uses): insn.erase() modified = True diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 58683a764..062742cd4 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -164,9 +164,10 @@ class DebugInfoEmitter: class LLVMIRGenerator: - def __init__(self, engine, module_name, target, object_map, type_map): + def __init__(self, engine, module_name, target, function_map, object_map, type_map): self.engine = engine self.target = target + self.function_map = function_map self.object_map = object_map self.type_map = type_map self.llcontext = target.llcontext @@ -393,22 +394,24 @@ class LLVMIRGenerator: return llglobal + def get_function(self, typ, name): + llfun = self.llmodule.get_global(name) + if llfun is None: + llfunty = self.llty_of_type(typ, bare=True) + llfun = ll.Function(self.llmodule, llfunty, name) + + llretty = self.llty_of_type(typ.ret, for_return=True) + if self.needs_sret(llretty): + llfun.args[0].add_attribute('sret') + return llfun + def map(self, value): if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)): return self.llmap[value] elif isinstance(value, ir.Constant): return self.llconst_of_const(value) elif isinstance(value, ir.Function): - llfun = self.llmodule.get_global(value.name) - if llfun is None: - llfunty = self.llty_of_type(value.type, bare=True) - llfun = ll.Function(self.llmodule, llfunty, value.name) - - llretty = self.llty_of_type(value.type.ret, for_return=True) - if self.needs_sret(llretty): - llfun.args[0].add_attribute('sret') - - return llfun + return self.get_function(value.type, value.name) else: assert False @@ -669,6 +672,9 @@ class LLVMIRGenerator: return list(typ.attributes.keys()).index(attr) def get_or_define_global(self, name, llty, llvalue=None): + if llvalue is None: + llvalue = ll.Constant(llty, ll.Undefined) + if name in self.llmodule.globals: llglobal = self.llmodule.get_global(name) else: @@ -683,12 +689,11 @@ class LLVMIRGenerator: llty = self.llty_of_type(typ).pointee return self.get_or_define_global("class.{}".format(typ.name), llty) - def get_closure(self, typ, attr): + def get_method(self, typ, attr): assert types.is_constructor(typ) assert types.is_function(typ.attributes[attr]) llty = self.llty_of_type(typ.attributes[attr]) - return self.get_or_define_global("method.{}.{}".format(typ.name, attr), - llty, ll.Constant(llty, ll.Undefined)) + return self.get_or_define_global("method.{}.{}".format(typ.name, attr), llty) def process_GetAttr(self, insn): typ, attr = insn.object().type, insn.attr @@ -710,7 +715,7 @@ class LLVMIRGenerator: assert False if types.is_method(insn.type) and attr not in typ.attributes: - llfun = self.llbuilder.load(self.get_closure(typ.constructor, attr)) + llfun = self.llbuilder.load(self.get_method(typ.constructor, attr)) llself = self.map(insn.object()) llmethodty = self.llty_of_type(insn.type) @@ -722,7 +727,7 @@ class LLVMIRGenerator: return llmethod elif types.is_function(insn.type) and attr in typ.attributes and \ types.is_constructor(typ): - return self.llbuilder.load(self.get_closure(typ, attr)) + return self.llbuilder.load(self.get_method(typ, attr)) else: llptr = self.llbuilder.gep(obj, [self.llindex(0), self.llindex(index)], inbounds=True, name=insn.name) @@ -742,7 +747,7 @@ class LLVMIRGenerator: if types.is_function(insn.value().type) and attr in typ.attributes and \ types.is_constructor(typ): - llptr = self.get_closure(typ, attr) + llptr = self.get_method(typ, attr) else: llptr = self.llbuilder.gep(obj, [self.llindex(0), self.llindex(self.attr_index(typ, attr))], @@ -987,8 +992,15 @@ class LLVMIRGenerator: assert False def process_Closure(self, insn): - llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined) llenv = self.llbuilder.bitcast(self.map(insn.environment()), llptr) + if insn.target_function.name in self.function_map.values(): + # If this closure belongs to a quoted function, we assume this is the only + # time that the closure is created, and record the environment globally + llenvptr = self.get_or_define_global("env.{}".format(insn.target_function.name), + llptr) + self.llbuilder.store(llenv, llenvptr) + + llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined) llvalue = self.llbuilder.insert_value(llvalue, llenv, 0) llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1, name=insn.name) @@ -1273,16 +1285,29 @@ class LLVMIRGenerator: llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr]) return llconst elif types.is_function(typ): - # RPC and C functions have no runtime representation; ARTIQ - # functions are initialized explicitly. + # RPC and C functions have no runtime representation. + # We only get down this codepath for ARTIQ Python functions when they're + # referenced from a constructor, and the value inside the constructor + # is never used. return ll.Constant(llty, ll.Undefined) else: print(typ) assert False def process_Quote(self, insn): - assert self.object_map is not None - return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) + if insn.value in self.function_map: + func_name = self.function_map[insn.value] + llenvptr = self.get_or_define_global("env.{}".format(func_name), llptr) + llenv = self.llbuilder.load(llenvptr) + llfun = self.get_function(insn.type.find(), func_name) + + llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + llclosure = self.llbuilder.insert_value(llclosure, llenv, 0) + llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name) + return llclosure + else: + assert self.object_map is not None + return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) def process_Select(self, insn): return self.llbuilder.select(self.map(insn.condition()), diff --git a/artiq/test/lit/devirtualization/function.py b/artiq/test/lit/devirtualization/function.py index db2e7afd0..c8fe65691 100644 --- a/artiq/test/lit/devirtualization/function.py +++ b/artiq/test/lit/devirtualization/function.py @@ -1,5 +1,6 @@ # RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s # RUN: OutputCheck %s --file-to-check=%t.txt +# XFAIL: * from artiq.language.core import * from artiq.language.types import * diff --git a/artiq/test/lit/devirtualization/method.py b/artiq/test/lit/devirtualization/method.py index 32578d599..8c258a41b 100644 --- a/artiq/test/lit/devirtualization/method.py +++ b/artiq/test/lit/devirtualization/method.py @@ -1,5 +1,6 @@ # RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s 2>%t # RUN: OutputCheck %s --file-to-check=%t.txt +# XFAIL: * from artiq.language.core import * from artiq.language.types import *