forked from M-Labs/artiq
compiler: make quoted functions independent of outer environment.
This commit is contained in:
parent
f5c720c3ee
commit
186a564ba8
@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef):
|
|||||||
_types = ("constructor_type",)
|
_types = ("constructor_type",)
|
||||||
class FunctionDefT(ast.FunctionDef, scoped):
|
class FunctionDefT(ast.FunctionDef, scoped):
|
||||||
_types = ("signature_type",)
|
_types = ("signature_type",)
|
||||||
|
class QuotedFunctionDefT(FunctionDefT):
|
||||||
|
pass
|
||||||
class ModuleT(ast.Module, scoped):
|
class ModuleT(ast.Module, scoped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ from Levenshtein import ratio as similarity, jaro_winkler
|
|||||||
from ..language import core as language_core
|
from ..language import core as language_core
|
||||||
from . import types, builtins, asttyped, prelude
|
from . import types, builtins, asttyped, prelude
|
||||||
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||||
|
from .transforms.asttyped_rewriter import LocalExtractor
|
||||||
|
|
||||||
|
|
||||||
def coredevice_print(x): print(x)
|
def coredevice_print(x): print(x)
|
||||||
@ -241,6 +242,35 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
|
|||||||
self.host_environment = host_environment
|
self.host_environment = host_environment
|
||||||
self.quote = quote
|
self.quote = quote
|
||||||
|
|
||||||
|
def visit_quoted_function(self, node, function):
|
||||||
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
|
extractor.visit(node)
|
||||||
|
|
||||||
|
# We quote the defaults so they end up in the global data in LLVM IR.
|
||||||
|
# This way there is no "life before main", i.e. they do not have to be
|
||||||
|
# constructed before the main translated call executes; but the Python
|
||||||
|
# semantics is kept.
|
||||||
|
defaults = function.__defaults__ or ()
|
||||||
|
quoted_defaults = []
|
||||||
|
for default, default_node in zip(defaults, node.args.defaults):
|
||||||
|
quoted_defaults.append(self.quote(default, default_node.loc))
|
||||||
|
node.args.defaults = quoted_defaults
|
||||||
|
|
||||||
|
node = asttyped.QuotedFunctionDefT(
|
||||||
|
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||||
|
signature_type=types.TVar(), return_type=types.TVar(),
|
||||||
|
name=node.name, args=node.args, returns=node.returns,
|
||||||
|
body=node.body, decorator_list=node.decorator_list,
|
||||||
|
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||||
|
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||||
|
loc=node.loc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.env_stack.append(node.typing_env)
|
||||||
|
return self.generic_visit(node)
|
||||||
|
finally:
|
||||||
|
self.env_stack.pop()
|
||||||
|
|
||||||
def visit_Name(self, node):
|
def visit_Name(self, node):
|
||||||
typ = super()._try_find_name(node.id)
|
typ = super()._try_find_name(node.id)
|
||||||
if typ is not None:
|
if typ is not None:
|
||||||
@ -562,7 +592,7 @@ class Stitcher:
|
|||||||
engine=self.engine, prelude=self.prelude,
|
engine=self.engine, prelude=self.prelude,
|
||||||
globals=self.globals, host_environment=host_environment,
|
globals=self.globals, host_environment=host_environment,
|
||||||
quote=self._quote)
|
quote=self._quote)
|
||||||
function_node = asttyped_rewriter.visit(function_node)
|
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
|
||||||
|
|
||||||
# Add it into our typedtree so that it gets inferenced and codegen'd.
|
# Add it into our typedtree so that it gets inferenced and codegen'd.
|
||||||
self._inject(function_node)
|
self._inject(function_node)
|
||||||
|
@ -94,6 +94,7 @@ class Target:
|
|||||||
llpassmgr.add_sroa_pass()
|
llpassmgr.add_sroa_pass()
|
||||||
llpassmgr.add_dead_code_elimination_pass()
|
llpassmgr.add_dead_code_elimination_pass()
|
||||||
llpassmgr.add_gvn_pass()
|
llpassmgr.add_gvn_pass()
|
||||||
|
llpassmgr.add_function_attrs_pass()
|
||||||
llpassmgr.run(llmodule)
|
llpassmgr.run(llmodule)
|
||||||
|
|
||||||
def compile(self, module):
|
def compile(self, module):
|
||||||
|
@ -224,7 +224,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
finally:
|
finally:
|
||||||
self.current_class = old_class
|
self.current_class = old_class
|
||||||
|
|
||||||
def visit_function(self, node, is_lambda, is_internal):
|
def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False):
|
||||||
if is_lambda:
|
if is_lambda:
|
||||||
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
|
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
|
||||||
typ = node.type.find()
|
typ = node.type.find()
|
||||||
@ -234,16 +234,24 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
defaults = []
|
defaults = []
|
||||||
|
if not is_quoted:
|
||||||
for arg_name, default_node in zip(typ.optargs, node.args.defaults):
|
for arg_name, default_node in zip(typ.optargs, node.args.defaults):
|
||||||
default = self.visit(default_node)
|
default = self.visit(default_node)
|
||||||
env_default_name = \
|
env_default_name = \
|
||||||
self.current_env.type.add("$default." + arg_name, default.type)
|
self.current_env.type.add("$default." + arg_name, default.type)
|
||||||
self.append(ir.SetLocal(self.current_env, env_default_name, default))
|
self.append(ir.SetLocal(self.current_env, env_default_name, default))
|
||||||
defaults.append(env_default_name)
|
def codegen_default(env_default_name):
|
||||||
|
return lambda: self.append(ir.GetLocal(self.current_env, env_default_name))
|
||||||
|
defaults.append(codegen_default(env_default_name))
|
||||||
|
else:
|
||||||
|
for default_node in node.args.defaults:
|
||||||
|
def codegen_default(default_node):
|
||||||
|
return lambda: self.visit(default_node)
|
||||||
|
defaults.append(codegen_default(default_node))
|
||||||
|
|
||||||
old_name, self.name = self.name, self.name + [name]
|
old_name, self.name = self.name, self.name + [name]
|
||||||
|
|
||||||
env_arg = ir.EnvironmentArgument(self.current_env.type, "CLS")
|
env_arg = ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")
|
||||||
|
|
||||||
old_args, self.current_args = self.current_args, {}
|
old_args, self.current_args = self.current_args, {}
|
||||||
|
|
||||||
@ -291,8 +299,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
self.append(ir.SetLocal(env, "$outer", env_arg))
|
self.append(ir.SetLocal(env, "$outer", env_arg))
|
||||||
for index, arg_name in enumerate(typ.args):
|
for index, arg_name in enumerate(typ.args):
|
||||||
self.append(ir.SetLocal(env, arg_name, args[index]))
|
self.append(ir.SetLocal(env, arg_name, args[index]))
|
||||||
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)):
|
for index, (arg_name, codegen_default) in enumerate(zip(typ.optargs, defaults)):
|
||||||
default = self.append(ir.GetLocal(self.current_env, env_default_name))
|
default = codegen_default()
|
||||||
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
|
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
|
||||||
typ.optargs[arg_name]))
|
typ.optargs[arg_name]))
|
||||||
self.append(ir.SetLocal(env, arg_name, value))
|
self.append(ir.SetLocal(env, arg_name, value))
|
||||||
@ -320,14 +328,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
return self.append(ir.Closure(func, self.current_env))
|
return self.append(ir.Closure(func, self.current_env))
|
||||||
|
|
||||||
def visit_FunctionDefT(self, node):
|
def visit_FunctionDefT(self, node):
|
||||||
func = self.visit_function(node, is_lambda=False,
|
func = self.visit_function(node, is_internal=len(self.name) > 0)
|
||||||
is_internal=len(self.name) > 0 or '.' in node.name)
|
|
||||||
if self.current_class is None:
|
if self.current_class is None:
|
||||||
if node.name in self.current_env.type.params:
|
|
||||||
self._set_local(node.name, func)
|
self._set_local(node.name, func)
|
||||||
else:
|
else:
|
||||||
self.append(ir.SetAttr(self.current_class, node.name, func))
|
self.append(ir.SetAttr(self.current_class, node.name, func))
|
||||||
|
|
||||||
|
def visit_QuotedFunctionDefT(self, node):
|
||||||
|
self.visit_function(node, is_internal=True, is_quoted=True)
|
||||||
|
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
if node.value is None:
|
if node.value is None:
|
||||||
return_value = ir.Constant(None, builtins.TNone())
|
return_value = ir.Constant(None, builtins.TNone())
|
||||||
@ -1676,7 +1685,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||||||
offset = 0
|
offset = 0
|
||||||
elif types.is_method(callee.type):
|
elif types.is_method(callee.type):
|
||||||
func = self.append(ir.GetAttr(callee, "__func__",
|
func = self.append(ir.GetAttr(callee, "__func__",
|
||||||
name="{}.CLS".format(callee.name)))
|
name="{}.ENV".format(callee.name)))
|
||||||
self_arg = self.append(ir.GetAttr(callee, "__self__",
|
self_arg = self.append(ir.GetAttr(callee, "__self__",
|
||||||
name="{}.SLF".format(callee.name)))
|
name="{}.SLF".format(callee.name)))
|
||||||
fn_typ = types.get_method_function(callee.type)
|
fn_typ = types.get_method_function(callee.type)
|
||||||
|
@ -229,9 +229,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
|||||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
extractor.visit(node)
|
extractor.visit(node)
|
||||||
|
|
||||||
signature_type = self._try_find_name(node.name)
|
signature_type = self._find_name(node.name, node.name_loc)
|
||||||
if signature_type is None:
|
|
||||||
signature_type = types.TVar()
|
|
||||||
|
|
||||||
node = asttyped.FunctionDefT(
|
node = asttyped.FunctionDefT(
|
||||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||||
|
@ -33,7 +33,7 @@ class DeadCodeEliminator:
|
|||||||
# it also has to run after the interleaver, but interleaver
|
# it also has to run after the interleaver, but interleaver
|
||||||
# doesn't like to work with IR before DCE.
|
# doesn't like to work with IR before DCE.
|
||||||
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
|
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
|
||||||
ir.Arith, ir.Compare, ir.Select, ir.Quote)) \
|
ir.Arith, ir.Compare, ir.Select, ir.Quote, ir.Closure)) \
|
||||||
and not any(insn.uses):
|
and not any(insn.uses):
|
||||||
insn.erase()
|
insn.erase()
|
||||||
modified = True
|
modified = True
|
||||||
|
@ -1227,6 +1227,8 @@ class Inferencer(algorithm.Visitor):
|
|||||||
self._unify(node.signature_type, signature_type,
|
self._unify(node.signature_type, signature_type,
|
||||||
node.name_loc, None)
|
node.name_loc, None)
|
||||||
|
|
||||||
|
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||||
|
|
||||||
def visit_ClassDefT(self, node):
|
def visit_ClassDefT(self, node):
|
||||||
if any(node.decorator_list):
|
if any(node.decorator_list):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
@ -142,6 +142,8 @@ class IODelayEstimator(algorithm.Visitor):
|
|||||||
body = node.body
|
body = node.body
|
||||||
self.visit_function(node.args, body, node.signature_type.find(), node.loc)
|
self.visit_function(node.args, body, node.signature_type.find(), node.loc)
|
||||||
|
|
||||||
|
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||||
|
|
||||||
def visit_LambdaT(self, node):
|
def visit_LambdaT(self, node):
|
||||||
self.visit_function(node.args, node.body, node.type.find(), node.loc)
|
self.visit_function(node.args, node.body, node.type.find(), node.loc)
|
||||||
|
|
||||||
|
@ -947,8 +947,8 @@ class LLVMIRGenerator:
|
|||||||
name=insn.name)
|
name=insn.name)
|
||||||
elif insn.op == "unwrap_or":
|
elif insn.op == "unwrap_or":
|
||||||
lloptarg, lldefault = map(self.map, insn.operands)
|
lloptarg, lldefault = map(self.map, insn.operands)
|
||||||
llhas_arg = self.llbuilder.extract_value(lloptarg, 0)
|
llhas_arg = self.llbuilder.extract_value(lloptarg, 0, name="opt.has")
|
||||||
llarg = self.llbuilder.extract_value(lloptarg, 1)
|
llarg = self.llbuilder.extract_value(lloptarg, 1, name="opt.val")
|
||||||
return self.llbuilder.select(llhas_arg, llarg, lldefault,
|
return self.llbuilder.select(llhas_arg, llarg, lldefault,
|
||||||
name=insn.name)
|
name=insn.name)
|
||||||
elif insn.op == "round":
|
elif insn.op == "round":
|
||||||
@ -1005,28 +1005,21 @@ class LLVMIRGenerator:
|
|||||||
def process_Closure(self, insn):
|
def process_Closure(self, insn):
|
||||||
llenv = self.map(insn.environment())
|
llenv = self.map(insn.environment())
|
||||||
llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name))
|
llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name))
|
||||||
if insn.target_function.name in self.function_map.values():
|
llfun = self.map(insn.target_function)
|
||||||
# If this closure belongs to a quoted function, we assume this is the only
|
|
||||||
# time that the closure is created, and record the environment globally
|
|
||||||
llenvptr = self.get_or_define_global("E.{}".format(insn.target_function.name),
|
|
||||||
llptr)
|
|
||||||
self.llbuilder.store(llenv, llenvptr)
|
|
||||||
|
|
||||||
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
|
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
|
||||||
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
|
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
|
||||||
llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1,
|
llvalue = self.llbuilder.insert_value(llvalue, llfun, 1, name=insn.name)
|
||||||
name=insn.name)
|
|
||||||
return llvalue
|
return llvalue
|
||||||
|
|
||||||
def _prepare_closure_call(self, insn):
|
def _prepare_closure_call(self, insn):
|
||||||
llargs = [self.map(arg) for arg in insn.arguments()]
|
llargs = [self.map(arg) for arg in insn.arguments()]
|
||||||
llclosure = self.map(insn.target_function())
|
llclosure = self.map(insn.target_function())
|
||||||
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.call")
|
|
||||||
if insn.static_target_function is None:
|
if insn.static_target_function is None:
|
||||||
llfun = self.llbuilder.extract_value(llclosure, 1,
|
llfun = self.llbuilder.extract_value(llclosure, 1,
|
||||||
name="fun.{}".format(llclosure.name))
|
name="fun.{}".format(llclosure.name))
|
||||||
else:
|
else:
|
||||||
llfun = self.map(insn.static_target_function)
|
llfun = self.map(insn.static_target_function)
|
||||||
|
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
|
||||||
return llfun, [llenv] + list(llargs)
|
return llfun, [llenv] + list(llargs)
|
||||||
|
|
||||||
def _prepare_ffi_call(self, insn):
|
def _prepare_ffi_call(self, insn):
|
||||||
@ -1315,13 +1308,8 @@ class LLVMIRGenerator:
|
|||||||
|
|
||||||
def process_Quote(self, insn):
|
def process_Quote(self, insn):
|
||||||
if insn.value in self.function_map:
|
if insn.value in self.function_map:
|
||||||
func_name = self.function_map[insn.value]
|
llfun = self.get_function(insn.type.find(), self.function_map[insn.value])
|
||||||
llenvptr = self.get_or_define_global("E.{}".format(func_name), llptr)
|
|
||||||
llenv = self.llbuilder.load(llenvptr)
|
|
||||||
llfun = self.get_function(insn.type.find(), func_name)
|
|
||||||
|
|
||||||
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
||||||
llclosure = self.llbuilder.insert_value(llclosure, llenv, 0)
|
|
||||||
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
|
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
|
||||||
return llclosure
|
return llclosure
|
||||||
else:
|
else:
|
||||||
|
@ -277,6 +277,8 @@ class EscapeValidator(algorithm.Visitor):
|
|||||||
self.visit_in_region(node, Region(node.loc), node.typing_env,
|
self.visit_in_region(node, Region(node.loc), node.typing_env,
|
||||||
args={ arg.arg: Argument(arg.loc) for arg in node.args.args })
|
args={ arg.arg: Argument(arg.loc) for arg in node.args.args })
|
||||||
|
|
||||||
|
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||||
|
|
||||||
def visit_ClassDefT(self, node):
|
def visit_ClassDefT(self, node):
|
||||||
self.youngest_env[node.name] = self.youngest_region
|
self.youngest_env[node.name] = self.youngest_region
|
||||||
self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes)
|
self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes)
|
||||||
|
@ -24,6 +24,8 @@ class MonomorphismValidator(algorithm.Visitor):
|
|||||||
node.name_loc, notes=[note])
|
node.name_loc, notes=[note])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
visit_QuotedFunctionDefT = visit_FunctionDefT
|
||||||
|
|
||||||
def generic_visit(self, node):
|
def generic_visit(self, node):
|
||||||
super().generic_visit(node)
|
super().generic_visit(node)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user