transforms.artiq_ir_generator: devirtualize method calls.

This commit is contained in:
whitequark 2015-10-09 02:27:52 +03:00
parent b6c8c9f480
commit 48f1f48f09
2 changed files with 37 additions and 6 deletions

View File

@ -6,9 +6,9 @@ but without too much detail, such as exposing the reference/value
semantics explicitly. semantics explicitly.
""" """
from collections import OrderedDict from collections import OrderedDict, defaultdict
from pythonparser import algorithm, diagnostic, ast from pythonparser import algorithm, diagnostic, ast
from .. import types, builtins, ir from .. import types, builtins, asttyped, ir
def _readable_name(insn): def _readable_name(insn):
if isinstance(insn, ir.Constant): if isinstance(insn, ir.Constant):
@ -100,18 +100,27 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.unwind_target = None self.unwind_target = None
self.function_map = dict() self.function_map = dict()
self.variable_map = dict() self.variable_map = dict()
self.method_map = dict() self.method_map = defaultdict(lambda: [])
def annotate_calls(self, devirtualization): def annotate_calls(self, devirtualization):
for var_node in devirtualization.variable_map: for var_node in devirtualization.variable_map:
callee_node = devirtualization.variable_map[var_node] callee_node = devirtualization.variable_map[var_node]
callee = self.function_map[callee_node] callee = self.function_map[callee_node]
call_target = self.variable_map[var_node] call_target = self.variable_map[var_node]
for use in call_target.uses: for use in call_target.uses:
if isinstance(use, (ir.Call, ir.Invoke)) and \ if isinstance(use, (ir.Call, ir.Invoke)) and \
use.target_function() == call_target: use.target_function() == call_target:
use.static_target_function = callee use.static_target_function = callee
for type_and_method in devirtualization.method_map:
callee_node = devirtualization.method_map[type_and_method]
callee = self.function_map[callee_node]
for call in self.method_map[type_and_method]:
assert isinstance(call, (ir.Call, ir.Invoke))
call.static_target_function = callee
def add_block(self, name=""): def add_block(self, name=""):
block = ir.BasicBlock([], name) block = ir.BasicBlock([], name)
self.current_function.add(block) self.current_function.add(block)
@ -1553,12 +1562,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
assert None not in args assert None not in args
if self.unwind_target is None: if self.unwind_target is None:
return self.append(ir.Call(func, args)) insn = self.append(ir.Call(func, args))
else: else:
after_invoke = self.add_block() after_invoke = self.add_block()
invoke = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target)) insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
self.current_block = after_invoke self.current_block = after_invoke
return invoke
method_key = None
if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func
self.method_map[(attr_node.value.type, attr_node.attr)].append(insn)
return insn
def visit_QuoteT(self, node): def visit_QuoteT(self, node):
return self.append(ir.Quote(node.value, node.type)) return self.append(ir.Quote(node.value, node.type))

View File

@ -0,0 +1,16 @@
# RUN: env ARTIQ_DUMP_IR=1 %python -m artiq.compiler.testbench.embedding +compile %s 2>%t
# RUN: OutputCheck %s --file-to-check=%t
from artiq.language.core import *
from artiq.language.types import *
class foo:
@kernel
def bar(self):
pass
x = foo()
@kernel
def entrypoint():
# CHECK-L: ; calls testbench.foo.bar
x.bar()