transforms.artiq_ir_generator: devirtualize closure calls.

This commit is contained in:
whitequark 2015-10-09 01:32:27 +03:00
parent 6922bd5638
commit 0bb793199f
3 changed files with 56 additions and 7 deletions

View File

@ -48,7 +48,7 @@ class Value:
self.uses, self.type = set(), typ.find() self.uses, self.type = set(), typ.find()
def replace_all_uses_with(self, value): def replace_all_uses_with(self, value):
for user in self.uses: for user in set(self.uses):
user.replace_uses_of(self, value) user.replace_uses_of(self, value)
class Constant(Value): class Constant(Value):
@ -126,11 +126,11 @@ class User(NamedValue):
self.set_operands([]) self.set_operands([])
def replace_uses_of(self, value, replacement): def replace_uses_of(self, value, replacement):
assert value in operands assert value in self.operands
for index, operand in enumerate(operands): for index, operand in enumerate(self.operands):
if operand == value: if operand == value:
operands[index] = replacement self.operands[index] = replacement
value.uses.remove(self) value.uses.remove(self)
replacement.uses.add(self) replacement.uses.add(self)
@ -851,6 +851,9 @@ class Closure(Instruction):
class Call(Instruction): class Call(Instruction):
""" """
A function call operation. A function call operation.
:ivar static_target_function: (:class:`Function` or None)
statically resolved callee
""" """
""" """
@ -861,6 +864,7 @@ class Call(Instruction):
assert isinstance(func, Value) assert isinstance(func, Value)
for arg in args: assert isinstance(arg, Value) for arg in args: assert isinstance(arg, Value)
super().__init__([func] + args, func.type.ret, name) super().__init__([func] + args, func.type.ret, name)
self.static_target_function = None
def opcode(self): def opcode(self):
return "call" return "call"
@ -871,6 +875,12 @@ class Call(Instruction):
def arguments(self): def arguments(self):
return self.operands[1:] return self.operands[1:]
def __str__(self):
result = super().__str__()
if self.static_target_function is not None:
result += " ; calls {}".format(self.static_target_function.name)
return result
class Select(Instruction): class Select(Instruction):
""" """
A conditional select instruction. A conditional select instruction.
@ -1080,6 +1090,9 @@ class Reraise(Terminator):
class Invoke(Terminator): class Invoke(Terminator):
""" """
A function call operation that supports exception handling. A function call operation that supports exception handling.
:ivar static_target_function: (:class:`Function` or None)
statically resolved callee
""" """
""" """
@ -1094,6 +1107,7 @@ class Invoke(Terminator):
assert isinstance(normal, BasicBlock) assert isinstance(normal, BasicBlock)
assert isinstance(exn, BasicBlock) assert isinstance(exn, BasicBlock)
super().__init__([func] + args + [normal, exn], func.type.ret, name) super().__init__([func] + args + [normal, exn], func.type.ret, name)
self.static_target_function = None
def opcode(self): def opcode(self):
return "invoke" return "invoke"
@ -1116,6 +1130,12 @@ class Invoke(Terminator):
self.operands[-1].as_operand()) self.operands[-1].as_operand())
return result return result
def __str__(self):
result = super().__str__()
if self.static_target_function is not None:
result += " ; calls {}".format(self.static_target_function.name)
return result
class LandingPad(Terminator): class LandingPad(Terminator):
""" """
An instruction that gives an incoming exception a name and An instruction that gives an incoming exception a name and

View File

@ -68,10 +68,9 @@ class Module:
iodelay_estimator.visit_fixpoint(src.typedtree) iodelay_estimator.visit_fixpoint(src.typedtree)
devirtualization.visit(src.typedtree) devirtualization.visit(src.typedtree)
self.artiq_ir = artiq_ir_generator.visit(src.typedtree) self.artiq_ir = artiq_ir_generator.visit(src.typedtree)
artiq_ir_generator.annotate_calls(devirtualization)
dead_code_eliminator.process(self.artiq_ir) dead_code_eliminator.process(self.artiq_ir)
local_access_validator.process(self.artiq_ir) local_access_validator.process(self.artiq_ir)
# for f in self.artiq_ir:
# print(f)
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""

View File

@ -63,6 +63,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
the basic block to which ``return`` will transfer control the basic block to which ``return`` will transfer control
:ivar unwind_target: (:class:`ir.BasicBlock` or None) :ivar unwind_target: (:class:`ir.BasicBlock` or None)
the basic block to which unwinding will transfer control the basic block to which unwinding will transfer control
There is, additionally, some global state that is used to translate
the results of analyses on AST level to IR level:
:ivar function_map: (map of :class:`ast.FunctionDefT` to :class:`ir.Function`)
the map from function definition nodes to IR functions
:ivar variable_map: (map of :class:`ast.NameT` to :class:`ir.GetLocal`)
the map from variable name nodes to instructions retrieving
the variable values
:ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`)
the map from method resolution nodes to instructions retrieving
the called function inside a translated :class:`ast.CallT` node
""" """
_size_type = builtins.TInt(types.TValue(32)) _size_type = builtins.TInt(types.TValue(32))
@ -86,6 +98,19 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.continue_target = None self.continue_target = None
self.return_target = None self.return_target = None
self.unwind_target = None self.unwind_target = None
self.function_map = dict()
self.variable_map = dict()
self.method_map = dict()
def annotate_calls(self, devirtualization):
for var_node in devirtualization.variable_map:
callee_node = devirtualization.variable_map[var_node]
callee = self.function_map[callee_node]
call_target = self.variable_map[var_node]
for use in call_target.uses:
if isinstance(use, (ir.Call, ir.Invoke)) and \
use.target_function() == call_target:
use.static_target_function = callee
def add_block(self, name=""): def add_block(self, name=""):
block = ir.BasicBlock([], name) block = ir.BasicBlock([], name)
@ -204,6 +229,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.functions.append(func) self.functions.append(func)
old_func, self.current_function = self.current_function, func old_func, self.current_function = self.current_function, func
if not is_lambda:
self.function_map[node] = func
entry = self.add_block() entry = self.add_block()
old_block, self.current_block = self.current_block, entry old_block, self.current_block = self.current_block, entry
@ -701,7 +729,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_NameT(self, node): def visit_NameT(self, node):
if self.current_assign is None: if self.current_assign is None:
return self._get_local(node.id) insn = self._get_local(node.id)
self.variable_map[node] = insn
return insn
else: else:
return self._set_local(node.id, self.current_assign) return self._set_local(node.id, self.current_assign)