compiler: make quoted functions independent of outer environment.

This commit is contained in:
whitequark 2016-03-26 20:01:51 +00:00
parent f5c720c3ee
commit 186a564ba8
11 changed files with 74 additions and 38 deletions

View File

@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef):
_types = ("constructor_type",) _types = ("constructor_type",)
class FunctionDefT(ast.FunctionDef, scoped): class FunctionDefT(ast.FunctionDef, scoped):
_types = ("signature_type",) _types = ("signature_type",)
class QuotedFunctionDefT(FunctionDefT):
pass
class ModuleT(ast.Module, scoped): class ModuleT(ast.Module, scoped):
pass pass

View File

@ -16,6 +16,7 @@ from Levenshtein import ratio as similarity, jaro_winkler
from ..language import core as language_core from ..language import core as language_core
from . import types, builtins, asttyped, prelude from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from .transforms.asttyped_rewriter import LocalExtractor
def coredevice_print(x): print(x) def coredevice_print(x): print(x)
@ -241,6 +242,35 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
self.host_environment = host_environment self.host_environment = host_environment
self.quote = quote self.quote = quote
def visit_quoted_function(self, node, function):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
# We quote the defaults so they end up in the global data in LLVM IR.
# This way there is no "life before main", i.e. they do not have to be
# constructed before the main translated call executes; but the Python
# semantics is kept.
defaults = function.__defaults__ or ()
quoted_defaults = []
for default, default_node in zip(defaults, node.args.defaults):
quoted_defaults.append(self.quote(default, default_node.loc))
node.args.defaults = quoted_defaults
node = asttyped.QuotedFunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
signature_type=types.TVar(), return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc)
try:
self.env_stack.append(node.typing_env)
return self.generic_visit(node)
finally:
self.env_stack.pop()
def visit_Name(self, node): def visit_Name(self, node):
typ = super()._try_find_name(node.id) typ = super()._try_find_name(node.id)
if typ is not None: if typ is not None:
@ -562,7 +592,7 @@ class Stitcher:
engine=self.engine, prelude=self.prelude, engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote=self._quote) quote=self._quote)
function_node = asttyped_rewriter.visit(function_node) function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
# Add it into our typedtree so that it gets inferenced and codegen'd. # Add it into our typedtree so that it gets inferenced and codegen'd.
self._inject(function_node) self._inject(function_node)

View File

@ -94,6 +94,7 @@ class Target:
llpassmgr.add_sroa_pass() llpassmgr.add_sroa_pass()
llpassmgr.add_dead_code_elimination_pass() llpassmgr.add_dead_code_elimination_pass()
llpassmgr.add_gvn_pass() llpassmgr.add_gvn_pass()
llpassmgr.add_function_attrs_pass()
llpassmgr.run(llmodule) llpassmgr.run(llmodule)
def compile(self, module): def compile(self, module):

View File

@ -224,7 +224,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
finally: finally:
self.current_class = old_class self.current_class = old_class
def visit_function(self, node, is_lambda, is_internal): def visit_function(self, node, is_lambda=False, is_internal=False, is_quoted=False):
if is_lambda: if is_lambda:
name = "lambda@{}:{}".format(node.loc.line(), node.loc.column()) name = "lambda@{}:{}".format(node.loc.line(), node.loc.column())
typ = node.type.find() typ = node.type.find()
@ -234,16 +234,24 @@ class ARTIQIRGenerator(algorithm.Visitor):
try: try:
defaults = [] defaults = []
if not is_quoted:
for arg_name, default_node in zip(typ.optargs, node.args.defaults): for arg_name, default_node in zip(typ.optargs, node.args.defaults):
default = self.visit(default_node) default = self.visit(default_node)
env_default_name = \ env_default_name = \
self.current_env.type.add("$default." + arg_name, default.type) self.current_env.type.add("$default." + arg_name, default.type)
self.append(ir.SetLocal(self.current_env, env_default_name, default)) self.append(ir.SetLocal(self.current_env, env_default_name, default))
defaults.append(env_default_name) def codegen_default(env_default_name):
return lambda: self.append(ir.GetLocal(self.current_env, env_default_name))
defaults.append(codegen_default(env_default_name))
else:
for default_node in node.args.defaults:
def codegen_default(default_node):
return lambda: self.visit(default_node)
defaults.append(codegen_default(default_node))
old_name, self.name = self.name, self.name + [name] old_name, self.name = self.name, self.name + [name]
env_arg = ir.EnvironmentArgument(self.current_env.type, "CLS") env_arg = ir.EnvironmentArgument(self.current_env.type, "ARG.ENV")
old_args, self.current_args = self.current_args, {} old_args, self.current_args = self.current_args, {}
@ -291,8 +299,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetLocal(env, "$outer", env_arg)) self.append(ir.SetLocal(env, "$outer", env_arg))
for index, arg_name in enumerate(typ.args): for index, arg_name in enumerate(typ.args):
self.append(ir.SetLocal(env, arg_name, args[index])) self.append(ir.SetLocal(env, arg_name, args[index]))
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)): for index, (arg_name, codegen_default) in enumerate(zip(typ.optargs, defaults)):
default = self.append(ir.GetLocal(self.current_env, env_default_name)) default = codegen_default()
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default], value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
typ.optargs[arg_name])) typ.optargs[arg_name]))
self.append(ir.SetLocal(env, arg_name, value)) self.append(ir.SetLocal(env, arg_name, value))
@ -320,14 +328,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Closure(func, self.current_env)) return self.append(ir.Closure(func, self.current_env))
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
func = self.visit_function(node, is_lambda=False, func = self.visit_function(node, is_internal=len(self.name) > 0)
is_internal=len(self.name) > 0 or '.' in node.name)
if self.current_class is None: if self.current_class is None:
if node.name in self.current_env.type.params:
self._set_local(node.name, func) self._set_local(node.name, func)
else: else:
self.append(ir.SetAttr(self.current_class, node.name, func)) self.append(ir.SetAttr(self.current_class, node.name, func))
def visit_QuotedFunctionDefT(self, node):
self.visit_function(node, is_internal=True, is_quoted=True)
def visit_Return(self, node): def visit_Return(self, node):
if node.value is None: if node.value is None:
return_value = ir.Constant(None, builtins.TNone()) return_value = ir.Constant(None, builtins.TNone())
@ -1676,7 +1685,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
offset = 0 offset = 0
elif types.is_method(callee.type): elif types.is_method(callee.type):
func = self.append(ir.GetAttr(callee, "__func__", func = self.append(ir.GetAttr(callee, "__func__",
name="{}.CLS".format(callee.name))) name="{}.ENV".format(callee.name)))
self_arg = self.append(ir.GetAttr(callee, "__self__", self_arg = self.append(ir.GetAttr(callee, "__self__",
name="{}.SLF".format(callee.name))) name="{}.SLF".format(callee.name)))
fn_typ = types.get_method_function(callee.type) fn_typ = types.get_method_function(callee.type)

View File

@ -229,9 +229,7 @@ class ASTTypedRewriter(algorithm.Transformer):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)
signature_type = self._try_find_name(node.name) signature_type = self._find_name(node.name, node.name_loc)
if signature_type is None:
signature_type = types.TVar()
node = asttyped.FunctionDefT( node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,

View File

@ -33,7 +33,7 @@ class DeadCodeEliminator:
# it also has to run after the interleaver, but interleaver # it also has to run after the interleaver, but interleaver
# doesn't like to work with IR before DCE. # doesn't like to work with IR before DCE.
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce, if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
ir.Arith, ir.Compare, ir.Select, ir.Quote)) \ ir.Arith, ir.Compare, ir.Select, ir.Quote, ir.Closure)) \
and not any(insn.uses): and not any(insn.uses):
insn.erase() insn.erase()
modified = True modified = True

View File

@ -1227,6 +1227,8 @@ class Inferencer(algorithm.Visitor):
self._unify(node.signature_type, signature_type, self._unify(node.signature_type, signature_type,
node.name_loc, None) node.name_loc, None)
visit_QuotedFunctionDefT = visit_FunctionDefT
def visit_ClassDefT(self, node): def visit_ClassDefT(self, node):
if any(node.decorator_list): if any(node.decorator_list):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",

View File

@ -142,6 +142,8 @@ class IODelayEstimator(algorithm.Visitor):
body = node.body body = node.body
self.visit_function(node.args, body, node.signature_type.find(), node.loc) self.visit_function(node.args, body, node.signature_type.find(), node.loc)
visit_QuotedFunctionDefT = visit_FunctionDefT
def visit_LambdaT(self, node): def visit_LambdaT(self, node):
self.visit_function(node.args, node.body, node.type.find(), node.loc) self.visit_function(node.args, node.body, node.type.find(), node.loc)

View File

@ -947,8 +947,8 @@ class LLVMIRGenerator:
name=insn.name) name=insn.name)
elif insn.op == "unwrap_or": elif insn.op == "unwrap_or":
lloptarg, lldefault = map(self.map, insn.operands) lloptarg, lldefault = map(self.map, insn.operands)
llhas_arg = self.llbuilder.extract_value(lloptarg, 0) llhas_arg = self.llbuilder.extract_value(lloptarg, 0, name="opt.has")
llarg = self.llbuilder.extract_value(lloptarg, 1) llarg = self.llbuilder.extract_value(lloptarg, 1, name="opt.val")
return self.llbuilder.select(llhas_arg, llarg, lldefault, return self.llbuilder.select(llhas_arg, llarg, lldefault,
name=insn.name) name=insn.name)
elif insn.op == "round": elif insn.op == "round":
@ -1005,28 +1005,21 @@ class LLVMIRGenerator:
def process_Closure(self, insn): def process_Closure(self, insn):
llenv = self.map(insn.environment()) llenv = self.map(insn.environment())
llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name)) llenv = self.llbuilder.bitcast(llenv, llptr, name="ptr.{}".format(llenv.name))
if insn.target_function.name in self.function_map.values(): llfun = self.map(insn.target_function)
# If this closure belongs to a quoted function, we assume this is the only
# time that the closure is created, and record the environment globally
llenvptr = self.get_or_define_global("E.{}".format(insn.target_function.name),
llptr)
self.llbuilder.store(llenv, llenvptr)
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined) llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0) llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1, llvalue = self.llbuilder.insert_value(llvalue, llfun, 1, name=insn.name)
name=insn.name)
return llvalue return llvalue
def _prepare_closure_call(self, insn): def _prepare_closure_call(self, insn):
llargs = [self.map(arg) for arg in insn.arguments()] llargs = [self.map(arg) for arg in insn.arguments()]
llclosure = self.map(insn.target_function()) llclosure = self.map(insn.target_function())
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.call")
if insn.static_target_function is None: if insn.static_target_function is None:
llfun = self.llbuilder.extract_value(llclosure, 1, llfun = self.llbuilder.extract_value(llclosure, 1,
name="fun.{}".format(llclosure.name)) name="fun.{}".format(llclosure.name))
else: else:
llfun = self.map(insn.static_target_function) llfun = self.map(insn.static_target_function)
llenv = self.llbuilder.extract_value(llclosure, 0, name="env.fun")
return llfun, [llenv] + list(llargs) return llfun, [llenv] + list(llargs)
def _prepare_ffi_call(self, insn): def _prepare_ffi_call(self, insn):
@ -1315,13 +1308,8 @@ class LLVMIRGenerator:
def process_Quote(self, insn): def process_Quote(self, insn):
if insn.value in self.function_map: if insn.value in self.function_map:
func_name = self.function_map[insn.value] llfun = self.get_function(insn.type.find(), self.function_map[insn.value])
llenvptr = self.get_or_define_global("E.{}".format(func_name), llptr)
llenv = self.llbuilder.load(llenvptr)
llfun = self.get_function(insn.type.find(), func_name)
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llclosure = self.llbuilder.insert_value(llclosure, llenv, 0)
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name) llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
return llclosure return llclosure
else: else:

View File

@ -277,6 +277,8 @@ class EscapeValidator(algorithm.Visitor):
self.visit_in_region(node, Region(node.loc), node.typing_env, self.visit_in_region(node, Region(node.loc), node.typing_env,
args={ arg.arg: Argument(arg.loc) for arg in node.args.args }) args={ arg.arg: Argument(arg.loc) for arg in node.args.args })
visit_QuotedFunctionDefT = visit_FunctionDefT
def visit_ClassDefT(self, node): def visit_ClassDefT(self, node):
self.youngest_env[node.name] = self.youngest_region self.youngest_env[node.name] = self.youngest_region
self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes) self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes)

View File

@ -24,6 +24,8 @@ class MonomorphismValidator(algorithm.Visitor):
node.name_loc, notes=[note]) node.name_loc, notes=[note])
self.engine.process(diag) self.engine.process(diag)
visit_QuotedFunctionDefT = visit_FunctionDefT
def generic_visit(self, node): def generic_visit(self, node):
super().generic_visit(node) super().generic_visit(node)