compiler: implement local variable demotion.

This commit is contained in:
whitequark 2018-05-19 17:05:34 +00:00
parent fd110e9848
commit 9b4ad8b5af
5 changed files with 82 additions and 0 deletions

View File

@ -60,6 +60,7 @@ class Module:
ref_period=ref_period)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=self.engine)
local_demoter = transforms.LocalDemoter()
devirtualization = analyses.Devirtualization()
interleaver = transforms.Interleaver(engine=self.engine)
invariant_detection = analyses.InvariantDetection(engine=self.engine)
@ -77,6 +78,7 @@ class Module:
dead_code_eliminator.process(self.artiq_ir)
interleaver.process(self.artiq_ir)
local_access_validator.process(self.artiq_ir)
local_demoter.process(self.artiq_ir)
if remarks:
invariant_detection.process(self.artiq_ir)

View File

@ -5,6 +5,7 @@ from .cast_monomorphizer import CastMonomorphizer
from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator
from .dead_code_eliminator import DeadCodeEliminator
from .local_demoter import LocalDemoter
from .llvm_ir_generator import LLVMIRGenerator
from .interleaver import Interleaver
from .typedtree_printer import TypedtreePrinter

View File

@ -0,0 +1,51 @@
"""
:class:`LocalDemoter` is a constant propagation transform:
it replaces reads of any local variable with only one write
in a function without closures with the value that was written.
:class:`LocalAccessValidator` must be run before this transform
to ensure that the transformation it performs is sound.
"""
from collections import defaultdict
from .. import ir
class LocalDemoter:
def process(self, functions):
for func in functions:
self.process_function(func)
def process_function(self, func):
env_safe = {}
env_gets = defaultdict(lambda: set())
env_sets = defaultdict(lambda: set())
for insn in func.instructions():
if isinstance(insn, (ir.GetLocal, ir.SetLocal)):
if "$" in insn.var_name:
continue
env = insn.environment()
if env not in env_safe:
for use in env.uses:
if not isinstance(use, (ir.GetLocal, ir.SetLocal)):
env_safe[env] = False
break
else:
env_safe[env] = True
if not env_safe[env]:
continue
if isinstance(insn, ir.SetLocal):
env_sets[(env, insn.var_name)].add(insn)
else:
env_gets[(env, insn.var_name)].add(insn)
for (env, var_name) in env_sets:
if len(env_sets[(env, var_name)]) == 1:
set_insn = next(iter(env_sets[(env, var_name)]))
for get_insn in env_gets[(env, var_name)]:
get_insn.replace_all_uses_with(set_insn.value())
get_insn.erase()

View File

@ -0,0 +1,15 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
def x(y): pass
# CHECK-L: NoneType input.a(environment(...) %ARG.ENV, NoneType %ARG.self) {
# CHECK-L: setlocal('self') %ENV, NoneType %ARG.self
# CHECK-NOT-L: call (y:NoneType)->NoneType %LOC.x, NoneType %ARG.self
def a(self):
def b():
pass
x(self)
a(None)

View File

@ -0,0 +1,13 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
def x(y): pass
# CHECK-L: NoneType input.a(environment(...) %ARG.ENV, NoneType %ARG.self) {
# CHECK-NOT-L: getlocal('self') %ENV
# CHECK-L: call (y:NoneType)->NoneType %LOC.x, NoneType %ARG.self
def a(self):
x(self)
a(None)