compiler: quote functions directly instead of going through a local.

This commit is contained in:
whitequark 2016-03-25 21:03:19 +00:00
parent 39d23793a4
commit 8d0566661a
8 changed files with 101 additions and 97 deletions

View File

@ -107,11 +107,8 @@ class ASTSynthesizer:
unquote_loc = self._add('`') unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc) loc = quote_loc.join(unquote_loc)
function_name, function_type = self.quote_function(value, self.expanded_from) function_type = self.quote_function(value, self.expanded_from)
if function_name is None: return asttyped.QuoteT(value=value, type=function_type, loc=loc)
return asttyped.QuoteT(value=value, type=function_type, loc=loc)
else:
return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc)
else: else:
quote_loc = self._add('`') quote_loc = self._add('`')
repr_loc = self._add(repr(value)) repr_loc = self._add(repr(value))
@ -155,7 +152,7 @@ class ASTSynthesizer:
return asttyped.QuoteT(value=value, type=instance_type, return asttyped.QuoteT(value=value, type=instance_type,
loc=loc) loc=loc)
def call(self, function_node, args, kwargs, callback=None): def call(self, callee, args, kwargs, callback=None):
""" """
Construct an AST fragment calling a function specified by Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments. an AST node `function_node`, with given arguments.
@ -164,11 +161,11 @@ class ASTSynthesizer:
callback_node = self.quote(callback) callback_node = self.quote(callback)
cb_begin_loc = self._add("(") cb_begin_loc = self._add("(")
callee_node = self.quote(callee)
arg_nodes = [] arg_nodes = []
kwarg_nodes = [] kwarg_nodes = []
kwarg_locs = [] kwarg_locs = []
name_loc = self._add(function_node.name)
begin_loc = self._add("(") begin_loc = self._add("(")
for index, arg in enumerate(args): for index, arg in enumerate(args):
arg_nodes.append(self.quote(arg)) arg_nodes.append(self.quote(arg))
@ -189,9 +186,7 @@ class ASTSynthesizer:
cb_end_loc = self._add(")") cb_end_loc = self._add(")")
node = asttyped.CallT( node = asttyped.CallT(
func=asttyped.NameT(id=function_node.name, ctx=None, func=callee_node,
type=function_node.signature_type,
loc=name_loc),
args=arg_nodes, args=arg_nodes,
keywords=[ast.keyword(arg=kw, value=value, keywords=[ast.keyword(arg=kw, value=value,
arg_loc=arg_loc, equals_loc=equals_loc, arg_loc=arg_loc, equals_loc=equals_loc,
@ -201,7 +196,7 @@ class ASTSynthesizer:
starargs=None, kwargs=None, starargs=None, kwargs=None,
type=types.TVar(), iodelay=None, arg_exprs={}, type=types.TVar(), iodelay=None, arg_exprs={},
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc)) loc=callee_node.loc.join(end_loc))
if callback is not None: if callback is not None:
node = asttyped.CallT( node = asttyped.CallT(
@ -213,19 +208,6 @@ class ASTSynthesizer:
return node return node
def assign_local(self, var_name, value):
name_loc = self._add(var_name)
_ = self._add(" ")
equals_loc = self._add("=")
_ = self._add(" ")
value_node = self.quote(value)
var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type,
loc=name_loc)
return ast.Assign(targets=[var_node], value=value_node,
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
def assign_attribute(self, obj, attr_name, value): def assign_attribute(self, obj, attr_name, value):
obj_node = self.quote(obj) obj_node = self.quote(obj)
dot_loc = self._add(".") dot_loc = self._add(".")
@ -465,18 +447,16 @@ class Stitcher:
self.functions = {} self.functions = {}
self.function_map = {}
self.object_map = ObjectMap() self.object_map = ObjectMap()
self.type_map = {} self.type_map = {}
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs, callback=None): def stitch_call(self, function, args, kwargs, callback=None):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)
# We synthesize source code for the initial call so that # We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user. # diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function)) synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function))
call_node = synthesizer.call(function_node, args, kwargs, callback) call_node = synthesizer.call(function, args, kwargs, callback)
synthesizer.finalize() synthesizer.finalize()
self.typedtree.append(call_node) self.typedtree.append(call_node)
@ -496,9 +476,8 @@ class Stitcher:
break break
old_typedtree_hash = typedtree_hash old_typedtree_hash = typedtree_hash
# For every host class we embed, add an appropriate constructor # For every host class we embed, fill in the function slots
# as a global. This is necessary for method lookup, which uses # with their corresponding closures.
# the getconstructor instruction.
for instance_type, constructor_type in list(self.type_map.values()): for instance_type, constructor_type in list(self.type_map.values()):
# Do we have any direct reference to a constructor? # Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0: if len(self.value_map[constructor_type]) > 0:
@ -509,13 +488,6 @@ class Stitcher:
instance, _instance_loc = self.value_map[instance_type][0] instance, _instance_loc = self.value_map[instance_type][0]
constructor = type(instance) constructor = type(instance)
self.globals[constructor_type.name] = constructor_type
synthesizer = self._synthesizer()
ast = synthesizer.assign_local(constructor_type.name, constructor)
synthesizer.finalize()
self._inject(ast)
for attr in constructor_type.attributes: for attr in constructor_type.attributes:
if types.is_function(constructor_type.attributes[attr]): if types.is_function(constructor_type.attributes[attr]):
synthesizer = self._synthesizer() synthesizer = self._synthesizer()
@ -577,23 +549,28 @@ class Stitcher:
# Mangle the name, since we put everything into a single module. # Mangle the name, since we put everything into a single module.
function_node.name = "{}.{}".format(module_name, function.__qualname__) function_node.name = "{}.{}".format(module_name, function.__qualname__)
# Normally, LocalExtractor would populate the typing environment # Record the function in the function map so that LLVM IR generator
# of the module with the function name. However, since we run # can handle quoting it.
# ASTTypedRewriter on the function node directly, we need to do it self.function_map[function] = function_node.name
# explicitly.
function_type = types.TVar()
self.globals[function_node.name] = function_type
# Memoize the function before typing it to handle recursive # Memoize the function type before typing it to handle recursive
# invocations. # invocations.
self.functions[function] = function_node.name, function_type self.functions[function] = types.TVar()
# Rewrite into typed form. # Rewrite into typed form.
asttyped_rewriter = StitchingASTTypedRewriter( asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude, engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment, globals=self.globals, host_environment=host_environment,
quote=self._quote) quote=self._quote)
return asttyped_rewriter.visit(function_node) function_node = asttyped_rewriter.visit(function_node)
# Add it into our typedtree so that it gets inferenced and codegen'd.
self._inject(function_node)
# Tie the typing knot.
self.functions[function].unify(function_node.signature_type)
return function_node
def _function_loc(self, function): def _function_loc(self, function):
filename = function.__code__.co_filename filename = function.__code__.co_filename
@ -734,14 +711,12 @@ class Stitcher:
function_type = types.TCFunction(arg_types, ret_type, function_type = types.TCFunction(arg_types, ret_type,
name=syscall) name=syscall)
self.functions[function] = None, function_type self.functions[function] = function_type
return None, function_type return function_type
def _quote_function(self, function, loc): def _quote_function(self, function, loc):
if function in self.functions: if function not in self.functions:
result = self.functions[function]
else:
if hasattr(function, "artiq_embedded"): if hasattr(function, "artiq_embedded"):
if function.artiq_embedded.function is not None: if function.artiq_embedded.function is not None:
if function.__name__ == "<lambda>": if function.__name__ == "<lambda>":
@ -766,17 +741,12 @@ class Stitcher:
notes=[note]) notes=[note])
self.engine.process(diag) self.engine.process(diag)
# Insert the typed AST for the new function and restart inference. self._quote_embedded_function(function)
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self._inject(function_node)
result = function_node.name, self.globals[function_node.name]
elif function.artiq_embedded.syscall is not None: elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call. # to perform a system call instead of a regular call.
result = self._quote_foreign_function(function, loc, self._quote_foreign_function(function, loc,
syscall=function.artiq_embedded.syscall) syscall=function.artiq_embedded.syscall)
elif function.artiq_embedded.forbidden is not None: elif function.artiq_embedded.forbidden is not None:
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"this function cannot be called as an RPC", {}, "this function cannot be called as an RPC", {},
@ -788,12 +758,12 @@ class Stitcher:
else: else:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler
# to perform an RPC instead of a regular call. # to perform an RPC instead of a regular call.
result = self._quote_foreign_function(function, loc, syscall=None) self._quote_foreign_function(function, loc, syscall=None)
function_name, function_type = result function_type = self.functions[function]
if types.is_rpc_function(function_type): if types.is_rpc_function(function_type):
function_type = types.instantiate(function_type) function_type = types.instantiate(function_type)
return function_name, function_type return function_type
def _quote(self, value, loc): def _quote(self, value, loc):
synthesizer = self._synthesizer(loc) synthesizer = self._synthesizer(loc)

View File

@ -19,6 +19,7 @@ class Source:
else: else:
self.engine = engine self.engine = engine
self.function_map = {}
self.object_map = None self.object_map = None
self.type_map = {} self.type_map = {}
@ -45,6 +46,7 @@ class Source:
class Module: class Module:
def __init__(self, src, ref_period=1e-6): def __init__(self, src, ref_period=1e-6):
self.engine = src.engine self.engine = src.engine
self.function_map = src.function_map
self.object_map = src.object_map self.object_map = src.object_map
self.type_map = src.type_map self.type_map = src.type_map
@ -80,7 +82,7 @@ class Module:
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""
llvm_ir_generator = transforms.LLVMIRGenerator( llvm_ir_generator = transforms.LLVMIRGenerator(
engine=self.engine, module_name=self.name, target=target, engine=self.engine, module_name=self.name, target=target,
object_map=self.object_map, type_map=self.type_map) function_map=self.function_map, object_map=self.object_map, type_map=self.type_map)
return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True) return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True)
def entry_point(self): def entry_point(self):

View File

@ -319,13 +319,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Closure(func, self.current_env)) return self.append(ir.Closure(func, self.current_env))
def visit_FunctionDefT(self, node, in_class=None): def visit_FunctionDefT(self, node):
func = self.visit_function(node, is_lambda=False, func = self.visit_function(node, is_lambda=False,
is_internal=len(self.name) > 0 or '.' in node.name) is_internal=len(self.name) > 0 or '.' in node.name)
if in_class is None: if self.current_class is None:
self._set_local(node.name, func) if node.name in self.current_env.type.params:
self._set_local(node.name, func)
else: else:
self.append(ir.SetAttr(in_class, node.name, func)) self.append(ir.SetAttr(self.current_class, node.name, func))
def visit_Return(self, node): def visit_Return(self, node):
if node.value is None: if node.value is None:

View File

@ -190,16 +190,16 @@ class ASTTypedRewriter(algorithm.Transformer):
self.in_class = None self.in_class = None
def _try_find_name(self, name): def _try_find_name(self, name):
for typing_env in reversed(self.env_stack):
if name in typing_env:
return typing_env[name]
def _find_name(self, name, loc):
if self.in_class is not None: if self.in_class is not None:
typ = self.in_class.constructor_type.attributes.get(name) typ = self.in_class.constructor_type.attributes.get(name)
if typ is not None: if typ is not None:
return typ return typ
for typing_env in reversed(self.env_stack):
if name in typing_env:
return typing_env[name]
def _find_name(self, name, loc):
typ = self._try_find_name(name) typ = self._try_find_name(name)
if typ is not None: if typ is not None:
return typ return typ
@ -229,9 +229,13 @@ class ASTTypedRewriter(algorithm.Transformer):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)
signature_type = self._try_find_name(node.name)
if signature_type is None:
signature_type = types.TVar()
node = asttyped.FunctionDefT( node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(), signature_type=signature_type, return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns, name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,

View File

@ -33,7 +33,7 @@ class DeadCodeEliminator:
# it also has to run after the interleaver, but interleaver # it also has to run after the interleaver, but interleaver
# doesn't like to work with IR before DCE. # doesn't like to work with IR before DCE.
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce, if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
ir.Arith, ir.Compare, ir.Closure, ir.Select, ir.Quote)) \ ir.Arith, ir.Compare, ir.Select, ir.Quote)) \
and not any(insn.uses): and not any(insn.uses):
insn.erase() insn.erase()
modified = True modified = True

View File

@ -164,9 +164,10 @@ class DebugInfoEmitter:
class LLVMIRGenerator: class LLVMIRGenerator:
def __init__(self, engine, module_name, target, object_map, type_map): def __init__(self, engine, module_name, target, function_map, object_map, type_map):
self.engine = engine self.engine = engine
self.target = target self.target = target
self.function_map = function_map
self.object_map = object_map self.object_map = object_map
self.type_map = type_map self.type_map = type_map
self.llcontext = target.llcontext self.llcontext = target.llcontext
@ -393,22 +394,24 @@ class LLVMIRGenerator:
return llglobal return llglobal
def get_function(self, typ, name):
llfun = self.llmodule.get_global(name)
if llfun is None:
llfunty = self.llty_of_type(typ, bare=True)
llfun = ll.Function(self.llmodule, llfunty, name)
llretty = self.llty_of_type(typ.ret, for_return=True)
if self.needs_sret(llretty):
llfun.args[0].add_attribute('sret')
return llfun
def map(self, value): def map(self, value):
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)): if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
return self.llmap[value] return self.llmap[value]
elif isinstance(value, ir.Constant): elif isinstance(value, ir.Constant):
return self.llconst_of_const(value) return self.llconst_of_const(value)
elif isinstance(value, ir.Function): elif isinstance(value, ir.Function):
llfun = self.llmodule.get_global(value.name) return self.get_function(value.type, value.name)
if llfun is None:
llfunty = self.llty_of_type(value.type, bare=True)
llfun = ll.Function(self.llmodule, llfunty, value.name)
llretty = self.llty_of_type(value.type.ret, for_return=True)
if self.needs_sret(llretty):
llfun.args[0].add_attribute('sret')
return llfun
else: else:
assert False assert False
@ -669,6 +672,9 @@ class LLVMIRGenerator:
return list(typ.attributes.keys()).index(attr) return list(typ.attributes.keys()).index(attr)
def get_or_define_global(self, name, llty, llvalue=None): def get_or_define_global(self, name, llty, llvalue=None):
if llvalue is None:
llvalue = ll.Constant(llty, ll.Undefined)
if name in self.llmodule.globals: if name in self.llmodule.globals:
llglobal = self.llmodule.get_global(name) llglobal = self.llmodule.get_global(name)
else: else:
@ -683,12 +689,11 @@ class LLVMIRGenerator:
llty = self.llty_of_type(typ).pointee llty = self.llty_of_type(typ).pointee
return self.get_or_define_global("class.{}".format(typ.name), llty) return self.get_or_define_global("class.{}".format(typ.name), llty)
def get_closure(self, typ, attr): def get_method(self, typ, attr):
assert types.is_constructor(typ) assert types.is_constructor(typ)
assert types.is_function(typ.attributes[attr]) assert types.is_function(typ.attributes[attr])
llty = self.llty_of_type(typ.attributes[attr]) llty = self.llty_of_type(typ.attributes[attr])
return self.get_or_define_global("method.{}.{}".format(typ.name, attr), return self.get_or_define_global("method.{}.{}".format(typ.name, attr), llty)
llty, ll.Constant(llty, ll.Undefined))
def process_GetAttr(self, insn): def process_GetAttr(self, insn):
typ, attr = insn.object().type, insn.attr typ, attr = insn.object().type, insn.attr
@ -710,7 +715,7 @@ class LLVMIRGenerator:
assert False assert False
if types.is_method(insn.type) and attr not in typ.attributes: if types.is_method(insn.type) and attr not in typ.attributes:
llfun = self.llbuilder.load(self.get_closure(typ.constructor, attr)) llfun = self.llbuilder.load(self.get_method(typ.constructor, attr))
llself = self.map(insn.object()) llself = self.map(insn.object())
llmethodty = self.llty_of_type(insn.type) llmethodty = self.llty_of_type(insn.type)
@ -722,7 +727,7 @@ class LLVMIRGenerator:
return llmethod return llmethod
elif types.is_function(insn.type) and attr in typ.attributes and \ elif types.is_function(insn.type) and attr in typ.attributes and \
types.is_constructor(typ): types.is_constructor(typ):
return self.llbuilder.load(self.get_closure(typ, attr)) return self.llbuilder.load(self.get_method(typ, attr))
else: else:
llptr = self.llbuilder.gep(obj, [self.llindex(0), self.llindex(index)], llptr = self.llbuilder.gep(obj, [self.llindex(0), self.llindex(index)],
inbounds=True, name=insn.name) inbounds=True, name=insn.name)
@ -742,7 +747,7 @@ class LLVMIRGenerator:
if types.is_function(insn.value().type) and attr in typ.attributes and \ if types.is_function(insn.value().type) and attr in typ.attributes and \
types.is_constructor(typ): types.is_constructor(typ):
llptr = self.get_closure(typ, attr) llptr = self.get_method(typ, attr)
else: else:
llptr = self.llbuilder.gep(obj, [self.llindex(0), llptr = self.llbuilder.gep(obj, [self.llindex(0),
self.llindex(self.attr_index(typ, attr))], self.llindex(self.attr_index(typ, attr))],
@ -987,8 +992,15 @@ class LLVMIRGenerator:
assert False assert False
def process_Closure(self, insn): def process_Closure(self, insn):
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
llenv = self.llbuilder.bitcast(self.map(insn.environment()), llptr) llenv = self.llbuilder.bitcast(self.map(insn.environment()), llptr)
if insn.target_function.name in self.function_map.values():
# If this closure belongs to a quoted function, we assume this is the only
# time that the closure is created, and record the environment globally
llenvptr = self.get_or_define_global("env.{}".format(insn.target_function.name),
llptr)
self.llbuilder.store(llenv, llenvptr)
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0) llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1, llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1,
name=insn.name) name=insn.name)
@ -1273,16 +1285,29 @@ class LLVMIRGenerator:
llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr]) llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr])
return llconst return llconst
elif types.is_function(typ): elif types.is_function(typ):
# RPC and C functions have no runtime representation; ARTIQ # RPC and C functions have no runtime representation.
# functions are initialized explicitly. # We only get down this codepath for ARTIQ Python functions when they're
# referenced from a constructor, and the value inside the constructor
# is never used.
return ll.Constant(llty, ll.Undefined) return ll.Constant(llty, ll.Undefined)
else: else:
print(typ) print(typ)
assert False assert False
def process_Quote(self, insn): def process_Quote(self, insn):
assert self.object_map is not None if insn.value in self.function_map:
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) func_name = self.function_map[insn.value]
llenvptr = self.get_or_define_global("env.{}".format(func_name), llptr)
llenv = self.llbuilder.load(llenvptr)
llfun = self.get_function(insn.type.find(), func_name)
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llclosure = self.llbuilder.insert_value(llclosure, llenv, 0)
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
return llclosure
else:
assert self.object_map is not None
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
def process_Select(self, insn): def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.condition()), return self.llbuilder.select(self.map(insn.condition()),

View File

@ -1,5 +1,6 @@
# RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s # RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t.txt # RUN: OutputCheck %s --file-to-check=%t.txt
# XFAIL: *
from artiq.language.core import * from artiq.language.core import *
from artiq.language.types import * from artiq.language.types import *

View File

@ -1,5 +1,6 @@
# RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s 2>%t # RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s 2>%t
# RUN: OutputCheck %s --file-to-check=%t.txt # RUN: OutputCheck %s --file-to-check=%t.txt
# XFAIL: *
from artiq.language.core import * from artiq.language.core import *
from artiq.language.types import * from artiq.language.types import *