forked from M-Labs/artiq
compiler: make quoted functions independent of outer environment.
This commit is contained in:
parent
f5c720c3ee
commit
186a564ba8
@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef):
|
||||
_types = ("constructor_type",)
|
||||
class FunctionDefT(ast.FunctionDef, scoped):
|
||||
_types = ("signature_type",)
|
||||
class QuotedFunctionDefT(FunctionDefT):
|
||||
pass
|
||||
class ModuleT(ast.Module, scoped):
|
||||
pass
|
||||
|
||||
|
@ -16,6 +16,7 @@ from Levenshtein import ratio as similarity, jaro_winkler
|
||||
from ..language import core as language_core
|
||||
from . import types, builtins, asttyped, prelude
|
||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||
from .transforms.asttyped_rewriter import LocalExtractor
|
||||
|
||||
|
||||
def coredevice_print(x): print(x)
|
||||
@ -241,6 +242,35 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
||||
self.host_environment = host_environment
|
||||
self.quote = quote
|
||||
|
||||
def visit_quoted_function(self, node, function):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
||||
# We quote the defaults so they end up in the global data in LLVM IR.
|
||||
# This way there is no "life before main", i.e. they do not have to be
|
||||
# constructed before the main translated call executes; but the Python
|
||||
# semantics is kept.
|
||||
defaults = function.__defaults__ or ()
|
||||
quoted_defaults = []
|
||||
for default, default_node in zip(defaults, node.args.defaults):
|
||||
quoted_defaults.append(self.quote(default, default_node.loc))
|
||||
node.args.defaults = quoted_defaults
|
||||
|
||||
node = asttyped.QuotedFunctionDefT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
signature_type=types.TVar(), return_type=types.TVar(),
|
||||
name=node.name, args=node.args, returns=node.returns,
|
||||
body=node.body, decorator_list=node.decorator_list,
|
||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||
loc=node.loc)
|
||||
|
||||
try:
|
||||
self.env_stack.append(node.typing_env)
|
||||
return self.generic_visit(node)
|
||||
finally:
|
||||
self.env_stack.pop()
|
||||
|
||||
def visit_Name(self, node):
|
||||
typ = super()._try_find_name(node.id)
|
||||
if typ is not None:
|
||||
@ -562,7 +592,7 @@ class Stitcher:
|
||||
engine=self.engine, prelude=self.prelude,
|
||||
globals=self.globals, host_environment=host_environment,
|
||||
quote=self._quote)
|
||||
function_node = asttyped_rewriter.visit(function_node)
|
||||
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
|
||||
|
||||
# Add it into our typedtree so that it gets inferenced and codegen'd.
|
||||
self._inject(function_node)
|
||||
|
@ -94,6 +94,7 @@ class Target:
|
||||
llpassmgr.add_sroa_pass()
|
||||
llpassmgr.add_dead_code_elimination_pass()
|
||||
llpassmgr.add_gvn_pass()
|
||||
llpassmgr.add_function_attrs_pass()
|
||||
llpassmgr.run(llmodule)
|
||||
|
||||
def compile(self, module):
|
||||
|
@ -224,7 +224,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
finally:
|
||||
self.current_class = old_class
|
||||
|
||||
def visit_function(self, node, is_lambda, is_internal):
|
||||
def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False):
|
||||
if is_lambda:
|
||||
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
|
||||
typ = node.type.find()
|
||||
@ -234,16 +234,24 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
try:
|
||||
defaults = []
|
||||
for arg_name, default_node in zip(typ.optargs, node.args.defaults):
|
||||
default = self.visit(default_node)
|
||||
env_default_name = \
|
||||
self.current_env.type.add("$default." + arg_name, default.type)
|
||||
self.append(ir.SetLocal(self.current_env, env_default_name, default))
|
||||
defaults.append(env_default_name)
|
||||
if not is_quoted:
|
||||
for arg_name, default_node in zip(typ.optargs, node.args.defaults):
|
||||
default = self.visit(default_node)
|
||||
env_default_name = \
|
||||
self.current_env.type.add("$default." + arg_name, default.type)
|
||||
self.append(ir.SetLocal(self.current_env, env_default_name, default))
|
||||
def codegen_default(env_default_name):
|
||||
return lambda: self.append(ir.GetLocal(self.current_env, env_default_name))
|
||||
defaults.append(codegen_default(env_default_name))
|
||||
else:
|
||||
for default_node in node.args.defaults:
|
||||
def codegen_default(default_node):
|
||||
return lambda: self.visit(default_node)
|
||||
defaults.append(codegen_default(default_node))
|
||||
|
||||
old_name, self.name = self.name, self.name + [name]
|
||||
|
||||
env_arg = ir.EnvironmentArgument(self.current_env.type, "CLS")
|
||||
env_arg = ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")
|
||||
|
||||
old_args, self.current_args = self.current_args, {}
|
||||
|
||||
@ -291,8 +299,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.append(ir.SetLocal(env, "$outer", env_arg))
|
||||
for index, arg_name in enumerate(typ.args):
|
||||
self.append(ir.SetLocal(env, arg_name, args[index]))
|
||||
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)):
|
||||
default = self.append(ir.GetLocal(self.current_env, env_default_name))
|
||||
for index, (arg_name, codegen_default) in enumerate(zip(typ.optargs, defaults)):
|
||||
default = codegen_default()
|
||||
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
|
||||
typ.optargs[arg_name]))
|
||||
self.append(ir.SetLocal(env, arg_name, value))
|
||||
@ -320,14 +328,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Closure(func, self.current_env))
|
||||
|
||||
def visit_FunctionDefT(self, node):
|
||||
func = self.visit_function(node, is_lambda=False,
|
||||
is_internal=len(self.name) > 0 or '.' in node.name)
|
||||
func = self.visit_function(node, is_internal=len(self.name) > 0)
|
||||
if self.current_class is None:
|
||||
if node.name in self.current_env.type.params:
|
||||
self._set_local(node.name, func)
|
||||
self._set_local(node.name, func)
|
||||
else:
|
||||
self.append(ir.SetAttr(self.current_class, node.name, func))
|
||||
|
||||
def visit_QuotedFunctionDefT(self, node):
|
||||
self.visit_function(node, is_internal=True, is_quoted=True)
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
return_value = ir.Constant(None, builtins.TNone())
|
||||
@ -1676,7 +1685,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
offset = 0
|
||||
elif types.is_method(callee.type):
|
||||
func = self.append(ir.GetAttr(callee, "__func__",
|
||||
name="{}.CLS".format(callee.name)))
|
||||
name="{}.ENV".format(callee.name)))
|
||||
self_arg = self.append(ir.GetAttr(callee, "__self__",
|
||||
name="{}.SLF".format(callee.name)))
|
||||
fn_typ = types.get_method_function(callee.type)
|
||||
|
@ -229,9 +229,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
||||
signature_type = self._try_find_name(node.name)
|
||||
if signature_type is None:
|
||||
signature_type = types.TVar()
|
||||
signature_type = self._find_name(node.name, node.name_loc)
|
||||
|
||||
node = asttyped.FunctionDefT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
|
@ -33,7 +33,7 @@ class DeadCodeEliminator:
|
||||
# it also has to run after the interleaver, but interleaver
|
||||
# doesn't like to work with IR before DCE.
|
||||
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
|
||||
ir.Arith, ir.Compare, ir.Select, ir.Quote)) \
|
||||
ir.Arith, ir.Compare, ir.Select, ir.Quote, ir.Closure)) \
|
||||
and not any(insn.uses):
|
||||
insn.erase()
|
||||
modified = True
|
||||
|
@ -1227,6 +1227,8 @@ class Inferencer(algorithm.Visitor):
|
||||
self._unify(node.signature_type, signature_type,
|
||||
node.name_loc, None)
|
||||
|
||||
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||
|
||||
def visit_ClassDefT(self, node):
|
||||
if any(node.decorator_list):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
|
@ -142,6 +142,8 @@ class IODelayEstimator(algorithm.Visitor):
|
||||
body = node.body
|
||||
self.visit_function(node.args, body, node.signature_type.find(), node.loc)
|
||||
|
||||
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||
|
||||
def visit_LambdaT(self, node):
|
||||
self.visit_function(node.args, node.body, node.type.find(), node.loc)
|
||||
|
||||
|
@ -947,8 +947,8 @@ class LLVMIRGenerator:
|
||||
name=insn.name)
|
||||
elif insn.op == "unwrap_or":
|
||||
lloptarg, lldefault = map(self.map, insn.operands)
|
||||
llhas_arg = self.llbuilder.extract_value(lloptarg, 0)
|
||||
llarg = self.llbuilder.extract_value(lloptarg, 1)
|
||||
llhas_arg = self.llbuilder.extract_value(lloptarg, 0, name="opt.has")
|
||||
llarg = self.llbuilder.extract_value(lloptarg, 1, name="opt.val")
|
||||
return self.llbuilder.select(llhas_arg, llarg, lldefault,
|
||||
name=insn.name)
|
||||
elif insn.op == "round":
|
||||
@ -1005,28 +1005,21 @@ class LLVMIRGenerator:
|
||||
def process_Closure(self, insn):
|
||||
llenv = self.map(insn.environment())
|
||||
llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name))
|
||||
if insn.target_function.name in self.function_map.values():
|
||||
# If this closure belongs to a quoted function, we assume this is the only
|
||||
# time that the closure is created, and record the environment globally
|
||||
llenvptr = self.get_or_define_global("E.{}".format(insn.target_function.name),
|
||||
llptr)
|
||||
self.llbuilder.store(llenv, llenvptr)
|
||||
|
||||
llfun = self.map(insn.target_function)
|
||||
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1,
|
||||
name=insn.name)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, llfun, 1, name=insn.name)
|
||||
return llvalue
|
||||
|
||||
def _prepare_closure_call(self, insn):
|
||||
llargs = [self.map(arg) for arg in insn.arguments()]
|
||||
llclosure = self.map(insn.target_function())
|
||||
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.call")
|
||||
if insn.static_target_function is None:
|
||||
llfun = self.llbuilder.extract_value(llclosure, 1,
|
||||
name="fun.{}".format(llclosure.name))
|
||||
else:
|
||||
llfun = self.map(insn.static_target_function)
|
||||
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
|
||||
return llfun, [llenv] + list(llargs)
|
||||
|
||||
def _prepare_ffi_call(self, insn):
|
||||
@ -1315,13 +1308,8 @@ class LLVMIRGenerator:
|
||||
|
||||
def process_Quote(self, insn):
|
||||
if insn.value in self.function_map:
|
||||
func_name = self.function_map[insn.value]
|
||||
llenvptr = self.get_or_define_global("E.{}".format(func_name), llptr)
|
||||
llenv = self.llbuilder.load(llenvptr)
|
||||
llfun = self.get_function(insn.type.find(), func_name)
|
||||
|
||||
llfun = self.get_function(insn.type.find(), self.function_map[insn.value])
|
||||
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
||||
llclosure = self.llbuilder.insert_value(llclosure, llenv, 0)
|
||||
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
|
||||
return llclosure
|
||||
else:
|
||||
|
@ -277,6 +277,8 @@ class EscapeValidator(algorithm.Visitor):
|
||||
self.visit_in_region(node, Region(node.loc), node.typing_env,
|
||||
args={ arg.arg: Argument(arg.loc) for arg in node.args.args })
|
||||
|
||||
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||
|
||||
def visit_ClassDefT(self, node):
|
||||
self.youngest_env[node.name] = self.youngest_region
|
||||
self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes)
|
||||
|
@ -24,6 +24,8 @@ class MonomorphismValidator(algorithm.Visitor):
|
||||
node.name_loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||
|
||||
def generic_visit(self, node):
|
||||
super().generic_visit(node)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user