diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 85e86d59d..77a5349ae 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef): _types = ("constructor_type",) class FunctionDefT(ast.FunctionDef, scoped): _types = ("signature_type",) +class QuotedFunctionDefT(FunctionDefT): + pass class ModuleT(ast.Module, scoped): pass diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index e49a65f3a..caa02bc9b 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -16,6 +16,7 @@ from Levenshtein import ratio as similarity, jaro_winkler from ..language import core as language_core from . import types, builtins, asttyped, prelude from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer +from .transforms.asttyped_rewriter import LocalExtractor def coredevice_print(x): print(x) @@ -241,6 +242,35 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): self.host_environment = host_environment self.quote = quote + def visit_quoted_function(self, node, function): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + # We quote the defaults so they end up in the global data in LLVM IR. + # This way there is no "life before main", i.e. they do not have to be + # constructed before the main translated call executes; but the Python + # semantics is kept. + defaults = function.__defaults__ or () + quoted_defaults = [] + for default, default_node in zip(defaults, node.args.defaults): + quoted_defaults.append(self.quote(default, default_node.loc)) + node.args.defaults = quoted_defaults + + node = asttyped.QuotedFunctionDefT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + signature_type=types.TVar(), return_type=types.TVar(), + name=node.name, args=node.args, returns=node.returns, + body=node.body, decorator_list=node.decorator_list, + keyword_loc=node.keyword_loc, name_loc=node.name_loc, + arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, + loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + def visit_Name(self, node): typ = super()._try_find_name(node.id) if typ is not None: @@ -562,7 +592,7 @@ class Stitcher: engine=self.engine, prelude=self.prelude, globals=self.globals, host_environment=host_environment, quote=self._quote) - function_node = asttyped_rewriter.visit(function_node) + function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function) # Add it into our typedtree so that it gets inferenced and codegen'd. self._inject(function_node) diff --git a/artiq/compiler/targets.py b/artiq/compiler/targets.py index ea8ed4ef0..3cd93dcde 100644 --- a/artiq/compiler/targets.py +++ b/artiq/compiler/targets.py @@ -94,6 +94,7 @@ class Target: llpassmgr.add_sroa_pass() llpassmgr.add_dead_code_elimination_pass() llpassmgr.add_gvn_pass() + llpassmgr.add_function_attrs_pass() llpassmgr.run(llmodule) def compile(self, module): diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 5030f0482..a0bac70b3 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -224,7 +224,7 @@ class ARTIQIRGenerator(algorithm.Visitor): finally: self.current_class = old_class - def visit_function(self, node, is_lambda, is_internal): + def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False): if is_lambda: name = "lambda@{}:{}".format(node.loc.line(), node.loc.column()) typ = node.type.find() @@ -234,16 +234,24 @@ class ARTIQIRGenerator(algorithm.Visitor): try: defaults = [] - for arg_name, default_node in zip(typ.optargs, node.args.defaults): - default = self.visit(default_node) - env_default_name = \ - self.current_env.type.add("$default." + arg_name, default.type) - self.append(ir.SetLocal(self.current_env, env_default_name, default)) - defaults.append(env_default_name) + if not is_quoted: + for arg_name, default_node in zip(typ.optargs, node.args.defaults): + default = self.visit(default_node) + env_default_name = \ + self.current_env.type.add("$default." + arg_name, default.type) + self.append(ir.SetLocal(self.current_env, env_default_name, default)) + def codegen_default(env_default_name): + return lambda: self.append(ir.GetLocal(self.current_env, env_default_name)) + defaults.append(codegen_default(env_default_name)) + else: + for default_node in node.args.defaults: + def codegen_default(default_node): + return lambda: self.visit(default_node) + defaults.append(codegen_default(default_node)) old_name, self.name = self.name, self.name + [name] - env_arg = ir.EnvironmentArgument(self.current_env.type, "CLS") + env_arg = ir.EnvironmentArgument(self.current_env.type, "ARG.ENV") old_args, self.current_args = self.current_args, {} @@ -291,8 +299,8 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.SetLocal(env, "$outer", env_arg)) for index, arg_name in enumerate(typ.args): self.append(ir.SetLocal(env, arg_name, args[index])) - for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)): - default = self.append(ir.GetLocal(self.current_env, env_default_name)) + for index, (arg_name, codegen_default) in enumerate(zip(typ.optargs, defaults)): + default = codegen_default() value = self.append(ir.Builtin("unwrap_or", [optargs[index], default], typ.optargs[arg_name])) self.append(ir.SetLocal(env, arg_name, value)) @@ -320,14 +328,15 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Closure(func, self.current_env)) def visit_FunctionDefT(self, node): - func = self.visit_function(node, is_lambda=False, - is_internal=len(self.name) > 0 or '.' in node.name) + func = self.visit_function(node, is_internal=len(self.name) > 0) if self.current_class is None: - if node.name in self.current_env.type.params: - self._set_local(node.name, func) + self._set_local(node.name, func) else: self.append(ir.SetAttr(self.current_class, node.name, func)) + def visit_QuotedFunctionDefT(self, node): + self.visit_function(node, is_internal=True, is_quoted=True) + def visit_Return(self, node): if node.value is None: return_value = ir.Constant(None, builtins.TNone()) @@ -1676,7 +1685,7 @@ class ARTIQIRGenerator(algorithm.Visitor): offset = 0 elif types.is_method(callee.type): func = self.append(ir.GetAttr(callee, "__func__", - name="{}.CLS".format(callee.name))) + name="{}.ENV".format(callee.name))) self_arg = self.append(ir.GetAttr(callee, "__self__", name="{}.SLF".format(callee.name))) fn_typ = types.get_method_function(callee.type) diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index 6a58748ee..b0714bb6f 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -229,9 +229,7 @@ class ASTTypedRewriter(algorithm.Transformer): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor.visit(node) - signature_type = self._try_find_name(node.name) - if signature_type is None: - signature_type = types.TVar() + signature_type = self._find_name(node.name, node.name_loc) node = asttyped.FunctionDefT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, diff --git a/artiq/compiler/transforms/dead_code_eliminator.py b/artiq/compiler/transforms/dead_code_eliminator.py index 6aae22dd4..2aceebeec 100644 --- a/artiq/compiler/transforms/dead_code_eliminator.py +++ b/artiq/compiler/transforms/dead_code_eliminator.py @@ -33,7 +33,7 @@ class DeadCodeEliminator: # it also has to run after the interleaver, but interleaver # doesn't like to work with IR before DCE. if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce, - ir.Arith, ir.Compare, ir.Select, ir.Quote)) \ + ir.Arith, ir.Compare, ir.Select, ir.Quote, ir.Closure)) \ and not any(insn.uses): insn.erase() modified = True diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 69465c0d7..fc8f6212e 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -1227,6 +1227,8 @@ class Inferencer(algorithm.Visitor): self._unify(node.signature_type, signature_type, node.name_loc, None) + visit_QuotedFunctionDefT = visit_FunctionDefT + def visit_ClassDefT(self, node): if any(node.decorator_list): diag = diagnostic.Diagnostic("error", diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py index 1e8cd2b36..fff7470de 100644 --- a/artiq/compiler/transforms/iodelay_estimator.py +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -142,6 +142,8 @@ class IODelayEstimator(algorithm.Visitor): body = node.body self.visit_function(node.args, body, node.signature_type.find(), node.loc) + visit_QuotedFunctionDefT = visit_FunctionDefT + def visit_LambdaT(self, node): self.visit_function(node.args, node.body, node.type.find(), node.loc) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 66113924d..705f48b96 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -947,8 +947,8 @@ class LLVMIRGenerator: name=insn.name) elif insn.op == "unwrap_or": lloptarg, lldefault = map(self.map, insn.operands) - llhas_arg = self.llbuilder.extract_value(lloptarg, 0) - llarg = self.llbuilder.extract_value(lloptarg, 1) + llhas_arg = self.llbuilder.extract_value(lloptarg, 0, name="opt.has") + llarg = self.llbuilder.extract_value(lloptarg, 1, name="opt.val") return self.llbuilder.select(llhas_arg, llarg, lldefault, name=insn.name) elif insn.op == "round": @@ -1005,28 +1005,21 @@ class LLVMIRGenerator: def process_Closure(self, insn): llenv = self.map(insn.environment()) llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name)) - if insn.target_function.name in self.function_map.values(): - # If this closure belongs to a quoted function, we assume this is the only - # time that the closure is created, and record the environment globally - llenvptr = self.get_or_define_global("E.{}".format(insn.target_function.name), - llptr) - self.llbuilder.store(llenv, llenvptr) - + llfun = self.map(insn.target_function) llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined) llvalue = self.llbuilder.insert_value(llvalue, llenv, 0) - llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1, - name=insn.name) + llvalue = self.llbuilder.insert_value(llvalue, llfun, 1, name=insn.name) return llvalue def _prepare_closure_call(self, insn): llargs = [self.map(arg) for arg in insn.arguments()] llclosure = self.map(insn.target_function()) - llenv = self.llbuilder.extract_value(llclosure, 0, name="env.call") if insn.static_target_function is None: llfun = self.llbuilder.extract_value(llclosure, 1, name="fun.{}".format(llclosure.name)) else: llfun = self.map(insn.static_target_function) + llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun") return llfun, [llenv] + list(llargs) def _prepare_ffi_call(self, insn): @@ -1315,13 +1308,8 @@ class LLVMIRGenerator: def process_Quote(self, insn): if insn.value in self.function_map: - func_name = self.function_map[insn.value] - llenvptr = self.get_or_define_global("E.{}".format(func_name), llptr) - llenv = self.llbuilder.load(llenvptr) - llfun = self.get_function(insn.type.find(), func_name) - + llfun = self.get_function(insn.type.find(), self.function_map[insn.value]) llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) - llclosure = self.llbuilder.insert_value(llclosure, llenv, 0) llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name) return llclosure else: diff --git a/artiq/compiler/validators/escape.py b/artiq/compiler/validators/escape.py index 00df17786..b2a9d47bc 100644 --- a/artiq/compiler/validators/escape.py +++ b/artiq/compiler/validators/escape.py @@ -277,6 +277,8 @@ class EscapeValidator(algorithm.Visitor): self.visit_in_region(node, Region(node.loc), node.typing_env, args={ arg.arg: Argument(arg.loc) for arg in node.args.args }) + visit_QuotedFunctionDefT = visit_FunctionDefT + def visit_ClassDefT(self, node): self.youngest_env[node.name] = self.youngest_region self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes) diff --git a/artiq/compiler/validators/monomorphism.py b/artiq/compiler/validators/monomorphism.py index 3eeddfac5..f30ac5288 100644 --- a/artiq/compiler/validators/monomorphism.py +++ b/artiq/compiler/validators/monomorphism.py @@ -24,6 +24,8 @@ class MonomorphismValidator(algorithm.Visitor): node.name_loc, notes=[note]) self.engine.process(diag) + visit_QuotedFunctionDefT = visit_FunctionDefT + def generic_visit(self, node): super().generic_visit(node)