forked from M-Labs/artiq
1
0
Fork 0

transforms.devirtualizer.FunctionResolver: implement.

This commit is contained in:
whitequark 2015-10-07 22:21:29 +03:00
parent 42b0089a4c
commit 6ac82e1439
3 changed files with 97 additions and 0 deletions

View File

@ -57,6 +57,7 @@ class Module:
ref_period=ref_period) ref_period=ref_period)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine) dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=self.engine) local_access_validator = validators.LocalAccessValidator(engine=self.engine)
devirtualizer = transforms.Devirtualizer()
self.name = src.name self.name = src.name
self.globals = src.globals self.globals = src.globals
@ -65,9 +66,12 @@ class Module:
monomorphism_validator.visit(src.typedtree) monomorphism_validator.visit(src.typedtree)
escape_validator.visit(src.typedtree) escape_validator.visit(src.typedtree)
iodelay_estimator.visit_fixpoint(src.typedtree) iodelay_estimator.visit_fixpoint(src.typedtree)
devirtualizer.visit(src.typedtree)
self.artiq_ir = artiq_ir_generator.visit(src.typedtree) self.artiq_ir = artiq_ir_generator.visit(src.typedtree)
dead_code_eliminator.process(self.artiq_ir) dead_code_eliminator.process(self.artiq_ir)
local_access_validator.process(self.artiq_ir) local_access_validator.process(self.artiq_ir)
# for f in self.artiq_ir:
# print(f)
def build_llvm_ir(self, target): def build_llvm_ir(self, target):
"""Compile the module to LLVM IR for the specified target.""" """Compile the module to LLVM IR for the specified target."""

View File

@ -3,5 +3,6 @@ from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer from .int_monomorphizer import IntMonomorphizer
from .iodelay_estimator import IODelayEstimator from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator from .artiq_ir_generator import ARTIQIRGenerator
from .devirtualizer import Devirtualizer
from .dead_code_eliminator import DeadCodeEliminator from .dead_code_eliminator import DeadCodeEliminator
from .llvm_ir_generator import LLVMIRGenerator from .llvm_ir_generator import LLVMIRGenerator

View File

@ -0,0 +1,92 @@
"""
:class:`Devirtualizer` performs method resolution at
compile time.
Devirtualization is implemented using a lattice
with three states: unknown assigned once diverges.
The lattice is computed individually for every
variable in scope as well as every
(constructor type, field name) pair.
"""
from pythonparser import algorithm
from .. import ir, types
def _advance(target_map, key, value):
if key not in target_map:
target_map[key] = value # unknown → assigned once
else:
target_map[key] = None # assigned once → diverges
class FunctionResolver(algorithm.Visitor):
def __init__(self, variable_map):
self.variable_map = variable_map
self.in_assign = False
self.scope_map = dict()
self.scope = None
self.queue = []
def finalize(self):
for thunk in self.queue:
thunk()
def visit_scope(self, node):
old_scope, self.scope = self.scope, node
self.generic_visit(node)
self.scope = old_scope
def visit_in_assign(self, node):
self.in_assign = True
self.visit(node)
self.in_assign = False
def visit_Assign(self, node):
self.visit(node.value)
self.visit_in_assign(node.targets)
def visit_For(self, node):
self.visit(node.iter)
self.visit_in_assign(node.target)
self.visit(node.body)
self.visit(node.orelse)
def visit_withitem(self, node):
self.visit(node.context_expr)
self.visit_in_assign(node.optional_vars)
def visit_comprehension(self, node):
self.visit(node.iter)
self.visit_in_assign(node.target)
self.visit(node.ifs)
def visit_ModuleT(self, node):
self.visit_scope(node)
def visit_FunctionDefT(self, node):
_advance(self.scope_map, (self.scope, node.name), node)
self.visit_scope(node)
def visit_NameT(self, node):
if self.in_assign:
# Just give up if we assign anything at all to a variable, and
# assume it diverges.
_advance(self.scope_map, (self.scope, node.id), None)
else:
# Copy the final value in scope_map into variable_map.
key = (self.scope, node.id)
def thunk():
if key in self.scope_map:
self.variable_map[node] = self.scope_map[key]
self.queue.append(thunk)
class Devirtualizer:
def __init__(self):
self.variable_map = dict()
self.method_map = dict()
def visit(self, node):
resolver = FunctionResolver(self.variable_map)
resolver.visit(node)
resolver.finalize()
# print(self.variable_map)