forked from M-Labs/artiq
compiler: quote functions directly instead of going through a local.
This commit is contained in:
parent
f72e050af5
commit
e534941383
@ -107,11 +107,8 @@ class ASTSynthesizer:
|
||||
unquote_loc = self._add('`')
|
||||
loc = quote_loc.join(unquote_loc)
|
||||
|
||||
function_name, function_type = self.quote_function(value, self.expanded_from)
|
||||
if function_name is None:
|
||||
return asttyped.QuoteT(value=value, type=function_type, loc=loc)
|
||||
else:
|
||||
return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc)
|
||||
function_type = self.quote_function(value, self.expanded_from)
|
||||
return asttyped.QuoteT(value=value, type=function_type, loc=loc)
|
||||
else:
|
||||
quote_loc = self._add('`')
|
||||
repr_loc = self._add(repr(value))
|
||||
@ -155,7 +152,7 @@ class ASTSynthesizer:
|
||||
return asttyped.QuoteT(value=value, type=instance_type,
|
||||
loc=loc)
|
||||
|
||||
def call(self, function_node, args, kwargs, callback=None):
|
||||
def call(self, callee, args, kwargs, callback=None):
|
||||
"""
|
||||
Construct an AST fragment calling a function specified by
|
||||
an AST node `function_node`, with given arguments.
|
||||
@ -164,11 +161,11 @@ class ASTSynthesizer:
|
||||
callback_node = self.quote(callback)
|
||||
cb_begin_loc = self._add("(")
|
||||
|
||||
callee_node = self.quote(callee)
|
||||
arg_nodes = []
|
||||
kwarg_nodes = []
|
||||
kwarg_locs = []
|
||||
|
||||
name_loc = self._add(function_node.name)
|
||||
begin_loc = self._add("(")
|
||||
for index, arg in enumerate(args):
|
||||
arg_nodes.append(self.quote(arg))
|
||||
@ -189,9 +186,7 @@ class ASTSynthesizer:
|
||||
cb_end_loc = self._add(")")
|
||||
|
||||
node = asttyped.CallT(
|
||||
func=asttyped.NameT(id=function_node.name, ctx=None,
|
||||
type=function_node.signature_type,
|
||||
loc=name_loc),
|
||||
func=callee_node,
|
||||
args=arg_nodes,
|
||||
keywords=[ast.keyword(arg=kw, value=value,
|
||||
arg_loc=arg_loc, equals_loc=equals_loc,
|
||||
@ -201,7 +196,7 @@ class ASTSynthesizer:
|
||||
starargs=None, kwargs=None,
|
||||
type=types.TVar(), iodelay=None, arg_exprs={},
|
||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||
loc=name_loc.join(end_loc))
|
||||
loc=callee_node.loc.join(end_loc))
|
||||
|
||||
if callback is not None:
|
||||
node = asttyped.CallT(
|
||||
@ -213,19 +208,6 @@ class ASTSynthesizer:
|
||||
|
||||
return node
|
||||
|
||||
def assign_local(self, var_name, value):
|
||||
name_loc = self._add(var_name)
|
||||
_ = self._add(" ")
|
||||
equals_loc = self._add("=")
|
||||
_ = self._add(" ")
|
||||
value_node = self.quote(value)
|
||||
|
||||
var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type,
|
||||
loc=name_loc)
|
||||
|
||||
return ast.Assign(targets=[var_node], value=value_node,
|
||||
op_locs=[equals_loc], loc=name_loc.join(value_node.loc))
|
||||
|
||||
def assign_attribute(self, obj, attr_name, value):
|
||||
obj_node = self.quote(obj)
|
||||
dot_loc = self._add(".")
|
||||
@ -465,18 +447,16 @@ class Stitcher:
|
||||
|
||||
self.functions = {}
|
||||
|
||||
self.function_map = {}
|
||||
self.object_map = ObjectMap()
|
||||
self.type_map = {}
|
||||
self.value_map = defaultdict(lambda: [])
|
||||
|
||||
def stitch_call(self, function, args, kwargs, callback=None):
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.append(function_node)
|
||||
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function))
|
||||
call_node = synthesizer.call(function_node, args, kwargs, callback)
|
||||
call_node = synthesizer.call(function, args, kwargs, callback)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
|
||||
@ -496,9 +476,8 @@ class Stitcher:
|
||||
break
|
||||
old_typedtree_hash = typedtree_hash
|
||||
|
||||
# For every host class we embed, add an appropriate constructor
|
||||
# as a global. This is necessary for method lookup, which uses
|
||||
# the getconstructor instruction.
|
||||
# For every host class we embed, fill in the function slots
|
||||
# with their corresponding closures.
|
||||
for instance_type, constructor_type in list(self.type_map.values()):
|
||||
# Do we have any direct reference to a constructor?
|
||||
if len(self.value_map[constructor_type]) > 0:
|
||||
@ -509,13 +488,6 @@ class Stitcher:
|
||||
instance, _instance_loc = self.value_map[instance_type][0]
|
||||
constructor = type(instance)
|
||||
|
||||
self.globals[constructor_type.name] = constructor_type
|
||||
|
||||
synthesizer = self._synthesizer()
|
||||
ast = synthesizer.assign_local(constructor_type.name, constructor)
|
||||
synthesizer.finalize()
|
||||
self._inject(ast)
|
||||
|
||||
for attr in constructor_type.attributes:
|
||||
if types.is_function(constructor_type.attributes[attr]):
|
||||
synthesizer = self._synthesizer()
|
||||
@ -577,23 +549,28 @@ class Stitcher:
|
||||
# Mangle the name, since we put everything into a single module.
|
||||
function_node.name = "{}.{}".format(module_name, function.__qualname__)
|
||||
|
||||
# Normally, LocalExtractor would populate the typing environment
|
||||
# of the module with the function name. However, since we run
|
||||
# ASTTypedRewriter on the function node directly, we need to do it
|
||||
# explicitly.
|
||||
function_type = types.TVar()
|
||||
self.globals[function_node.name] = function_type
|
||||
# Record the function in the function map so that LLVM IR generator
|
||||
# can handle quoting it.
|
||||
self.function_map[function] = function_node.name
|
||||
|
||||
# Memoize the function before typing it to handle recursive
|
||||
# Memoize the function type before typing it to handle recursive
|
||||
# invocations.
|
||||
self.functions[function] = function_node.name, function_type
|
||||
self.functions[function] = types.TVar()
|
||||
|
||||
# Rewrite into typed form.
|
||||
asttyped_rewriter = StitchingASTTypedRewriter(
|
||||
engine=self.engine, prelude=self.prelude,
|
||||
globals=self.globals, host_environment=host_environment,
|
||||
quote=self._quote)
|
||||
return asttyped_rewriter.visit(function_node)
|
||||
function_node = asttyped_rewriter.visit(function_node)
|
||||
|
||||
# Add it into our typedtree so that it gets inferenced and codegen'd.
|
||||
self._inject(function_node)
|
||||
|
||||
# Tie the typing knot.
|
||||
self.functions[function].unify(function_node.signature_type)
|
||||
|
||||
return function_node
|
||||
|
||||
def _function_loc(self, function):
|
||||
filename = function.__code__.co_filename
|
||||
@ -734,14 +711,12 @@ class Stitcher:
|
||||
function_type = types.TCFunction(arg_types, ret_type,
|
||||
name=syscall)
|
||||
|
||||
self.functions[function] = None, function_type
|
||||
self.functions[function] = function_type
|
||||
|
||||
return None, function_type
|
||||
return function_type
|
||||
|
||||
def _quote_function(self, function, loc):
|
||||
if function in self.functions:
|
||||
result = self.functions[function]
|
||||
else:
|
||||
if function not in self.functions:
|
||||
if hasattr(function, "artiq_embedded"):
|
||||
if function.artiq_embedded.function is not None:
|
||||
if function.__name__ == "<lambda>":
|
||||
@ -766,17 +741,12 @@ class Stitcher:
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
# Insert the typed AST for the new function and restart inference.
|
||||
# It doesn't really matter where we insert as long as it is before
|
||||
# the final call.
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self._inject(function_node)
|
||||
result = function_node.name, self.globals[function_node.name]
|
||||
self._quote_embedded_function(function)
|
||||
elif function.artiq_embedded.syscall is not None:
|
||||
# Insert a storage-less global whose type instructs the compiler
|
||||
# to perform a system call instead of a regular call.
|
||||
result = self._quote_foreign_function(function, loc,
|
||||
syscall=function.artiq_embedded.syscall)
|
||||
self._quote_foreign_function(function, loc,
|
||||
syscall=function.artiq_embedded.syscall)
|
||||
elif function.artiq_embedded.forbidden is not None:
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"this function cannot be called as an RPC", {},
|
||||
@ -788,12 +758,12 @@ class Stitcher:
|
||||
else:
|
||||
# Insert a storage-less global whose type instructs the compiler
|
||||
# to perform an RPC instead of a regular call.
|
||||
result = self._quote_foreign_function(function, loc, syscall=None)
|
||||
self._quote_foreign_function(function, loc, syscall=None)
|
||||
|
||||
function_name, function_type = result
|
||||
function_type = self.functions[function]
|
||||
if types.is_rpc_function(function_type):
|
||||
function_type = types.instantiate(function_type)
|
||||
return function_name, function_type
|
||||
return function_type
|
||||
|
||||
def _quote(self, value, loc):
|
||||
synthesizer = self._synthesizer(loc)
|
||||
|
@ -19,6 +19,7 @@ class Source:
|
||||
else:
|
||||
self.engine = engine
|
||||
|
||||
self.function_map = {}
|
||||
self.object_map = None
|
||||
self.type_map = {}
|
||||
|
||||
@ -45,6 +46,7 @@ class Source:
|
||||
class Module:
|
||||
def __init__(self, src, ref_period=1e-6):
|
||||
self.engine = src.engine
|
||||
self.function_map = src.function_map
|
||||
self.object_map = src.object_map
|
||||
self.type_map = src.type_map
|
||||
|
||||
@ -80,7 +82,7 @@ class Module:
|
||||
"""Compile the module to LLVM IR for the specified target."""
|
||||
llvm_ir_generator = transforms.LLVMIRGenerator(
|
||||
engine=self.engine, module_name=self.name, target=target,
|
||||
object_map=self.object_map, type_map=self.type_map)
|
||||
function_map=self.function_map, object_map=self.object_map, type_map=self.type_map)
|
||||
return llvm_ir_generator.process(self.artiq_ir, attribute_writeback=True)
|
||||
|
||||
def entry_point(self):
|
||||
|
@ -319,13 +319,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
return self.append(ir.Closure(func, self.current_env))
|
||||
|
||||
def visit_FunctionDefT(self, node, in_class=None):
|
||||
def visit_FunctionDefT(self, node):
|
||||
func = self.visit_function(node, is_lambda=False,
|
||||
is_internal=len(self.name) > 0 or '.' in node.name)
|
||||
if in_class is None:
|
||||
self._set_local(node.name, func)
|
||||
if self.current_class is None:
|
||||
if node.name in self.current_env.type.params:
|
||||
self._set_local(node.name, func)
|
||||
else:
|
||||
self.append(ir.SetAttr(in_class, node.name, func))
|
||||
self.append(ir.SetAttr(self.current_class, node.name, func))
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
|
@ -190,16 +190,16 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
self.in_class = None
|
||||
|
||||
def _try_find_name(self, name):
|
||||
for typing_env in reversed(self.env_stack):
|
||||
if name in typing_env:
|
||||
return typing_env[name]
|
||||
|
||||
def _find_name(self, name, loc):
|
||||
if self.in_class is not None:
|
||||
typ = self.in_class.constructor_type.attributes.get(name)
|
||||
if typ is not None:
|
||||
return typ
|
||||
|
||||
for typing_env in reversed(self.env_stack):
|
||||
if name in typing_env:
|
||||
return typing_env[name]
|
||||
|
||||
def _find_name(self, name, loc):
|
||||
typ = self._try_find_name(name)
|
||||
if typ is not None:
|
||||
return typ
|
||||
@ -229,9 +229,13 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||
extractor.visit(node)
|
||||
|
||||
signature_type = self._try_find_name(node.name)
|
||||
if signature_type is None:
|
||||
signature_type = types.TVar()
|
||||
|
||||
node = asttyped.FunctionDefT(
|
||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||
signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(),
|
||||
signature_type=signature_type, return_type=types.TVar(),
|
||||
name=node.name, args=node.args, returns=node.returns,
|
||||
body=node.body, decorator_list=node.decorator_list,
|
||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||
|
@ -33,7 +33,7 @@ class DeadCodeEliminator:
|
||||
# it also has to run after the interleaver, but interleaver
|
||||
# doesn't like to work with IR before DCE.
|
||||
if isinstance(insn, (ir.Phi, ir.Alloc, ir.GetAttr, ir.GetElem, ir.Coerce,
|
||||
ir.Arith, ir.Compare, ir.Closure, ir.Select, ir.Quote)) \
|
||||
ir.Arith, ir.Compare, ir.Select, ir.Quote)) \
|
||||
and not any(insn.uses):
|
||||
insn.erase()
|
||||
modified = True
|
||||
|
@ -164,9 +164,10 @@ class DebugInfoEmitter:
|
||||
|
||||
|
||||
class LLVMIRGenerator:
|
||||
def __init__(self, engine, module_name, target, object_map, type_map):
|
||||
def __init__(self, engine, module_name, target, function_map, object_map, type_map):
|
||||
self.engine = engine
|
||||
self.target = target
|
||||
self.function_map = function_map
|
||||
self.object_map = object_map
|
||||
self.type_map = type_map
|
||||
self.llcontext = target.llcontext
|
||||
@ -393,22 +394,24 @@ class LLVMIRGenerator:
|
||||
|
||||
return llglobal
|
||||
|
||||
def get_function(self, typ, name):
|
||||
llfun = self.llmodule.get_global(name)
|
||||
if llfun is None:
|
||||
llfunty = self.llty_of_type(typ, bare=True)
|
||||
llfun = ll.Function(self.llmodule, llfunty, name)
|
||||
|
||||
llretty = self.llty_of_type(typ.ret, for_return=True)
|
||||
if self.needs_sret(llretty):
|
||||
llfun.args[0].add_attribute('sret')
|
||||
return llfun
|
||||
|
||||
def map(self, value):
|
||||
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
|
||||
return self.llmap[value]
|
||||
elif isinstance(value, ir.Constant):
|
||||
return self.llconst_of_const(value)
|
||||
elif isinstance(value, ir.Function):
|
||||
llfun = self.llmodule.get_global(value.name)
|
||||
if llfun is None:
|
||||
llfunty = self.llty_of_type(value.type, bare=True)
|
||||
llfun = ll.Function(self.llmodule, llfunty, value.name)
|
||||
|
||||
llretty = self.llty_of_type(value.type.ret, for_return=True)
|
||||
if self.needs_sret(llretty):
|
||||
llfun.args[0].add_attribute('sret')
|
||||
|
||||
return llfun
|
||||
return self.get_function(value.type, value.name)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@ -669,6 +672,9 @@ class LLVMIRGenerator:
|
||||
return list(typ.attributes.keys()).index(attr)
|
||||
|
||||
def get_or_define_global(self, name, llty, llvalue=None):
|
||||
if llvalue is None:
|
||||
llvalue = ll.Constant(llty, ll.Undefined)
|
||||
|
||||
if name in self.llmodule.globals:
|
||||
llglobal = self.llmodule.get_global(name)
|
||||
else:
|
||||
@ -683,12 +689,11 @@ class LLVMIRGenerator:
|
||||
llty = self.llty_of_type(typ).pointee
|
||||
return self.get_or_define_global("class.{}".format(typ.name), llty)
|
||||
|
||||
def get_closure(self, typ, attr):
|
||||
def get_method(self, typ, attr):
|
||||
assert types.is_constructor(typ)
|
||||
assert types.is_function(typ.attributes[attr])
|
||||
llty = self.llty_of_type(typ.attributes[attr])
|
||||
return self.get_or_define_global("method.{}.{}".format(typ.name, attr),
|
||||
llty, ll.Constant(llty, ll.Undefined))
|
||||
return self.get_or_define_global("method.{}.{}".format(typ.name, attr), llty)
|
||||
|
||||
def process_GetAttr(self, insn):
|
||||
typ, attr = insn.object().type, insn.attr
|
||||
@ -710,7 +715,7 @@ class LLVMIRGenerator:
|
||||
assert False
|
||||
|
||||
if types.is_method(insn.type) and attr not in typ.attributes:
|
||||
llfun = self.llbuilder.load(self.get_closure(typ.constructor, attr))
|
||||
llfun = self.llbuilder.load(self.get_method(typ.constructor, attr))
|
||||
llself = self.map(insn.object())
|
||||
|
||||
llmethodty = self.llty_of_type(insn.type)
|
||||
@ -722,7 +727,7 @@ class LLVMIRGenerator:
|
||||
return llmethod
|
||||
elif types.is_function(insn.type) and attr in typ.attributes and \
|
||||
types.is_constructor(typ):
|
||||
return self.llbuilder.load(self.get_closure(typ, attr))
|
||||
return self.llbuilder.load(self.get_method(typ, attr))
|
||||
else:
|
||||
llptr = self.llbuilder.gep(obj, [self.llindex(0), self.llindex(index)],
|
||||
inbounds=True, name=insn.name)
|
||||
@ -742,7 +747,7 @@ class LLVMIRGenerator:
|
||||
|
||||
if types.is_function(insn.value().type) and attr in typ.attributes and \
|
||||
types.is_constructor(typ):
|
||||
llptr = self.get_closure(typ, attr)
|
||||
llptr = self.get_method(typ, attr)
|
||||
else:
|
||||
llptr = self.llbuilder.gep(obj, [self.llindex(0),
|
||||
self.llindex(self.attr_index(typ, attr))],
|
||||
@ -987,8 +992,15 @@ class LLVMIRGenerator:
|
||||
assert False
|
||||
|
||||
def process_Closure(self, insn):
|
||||
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
|
||||
llenv = self.llbuilder.bitcast(self.map(insn.environment()), llptr)
|
||||
if insn.target_function.name in self.function_map.values():
|
||||
# If this closure belongs to a quoted function, we assume this is the only
|
||||
# time that the closure is created, and record the environment globally
|
||||
llenvptr = self.get_or_define_global("env.{}".format(insn.target_function.name),
|
||||
llptr)
|
||||
self.llbuilder.store(llenv, llenvptr)
|
||||
|
||||
llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, llenv, 0)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1,
|
||||
name=insn.name)
|
||||
@ -1273,16 +1285,29 @@ class LLVMIRGenerator:
|
||||
llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr])
|
||||
return llconst
|
||||
elif types.is_function(typ):
|
||||
# RPC and C functions have no runtime representation; ARTIQ
|
||||
# functions are initialized explicitly.
|
||||
# RPC and C functions have no runtime representation.
|
||||
# We only get down this codepath for ARTIQ Python functions when they're
|
||||
# referenced from a constructor, and the value inside the constructor
|
||||
# is never used.
|
||||
return ll.Constant(llty, ll.Undefined)
|
||||
else:
|
||||
print(typ)
|
||||
assert False
|
||||
|
||||
def process_Quote(self, insn):
|
||||
assert self.object_map is not None
|
||||
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
|
||||
if insn.value in self.function_map:
|
||||
func_name = self.function_map[insn.value]
|
||||
llenvptr = self.get_or_define_global("env.{}".format(func_name), llptr)
|
||||
llenv = self.llbuilder.load(llenvptr)
|
||||
llfun = self.get_function(insn.type.find(), func_name)
|
||||
|
||||
llclosure = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
||||
llclosure = self.llbuilder.insert_value(llclosure, llenv, 0)
|
||||
llclosure = self.llbuilder.insert_value(llclosure, llfun, 1, name=insn.name)
|
||||
return llclosure
|
||||
else:
|
||||
assert self.object_map is not None
|
||||
return self._quote(insn.value, insn.type, lambda: [repr(insn.value)])
|
||||
|
||||
def process_Select(self, insn):
|
||||
return self.llbuilder.select(self.map(insn.condition()),
|
||||
|
@ -1,5 +1,6 @@
|
||||
# RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.txt
|
||||
# XFAIL: *
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
@ -1,5 +1,6 @@
|
||||
# RUN: env ARTIQ_DUMP_IR=%t %python -m artiq.compiler.testbench.embedding +compile %s 2>%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t.txt
|
||||
# XFAIL: *
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
Loading…
Reference in New Issue
Block a user