2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-25 03:08:27 +08:00

Add LocalAccessValidator.

This commit is contained in:
whitequark 2015-07-19 11:44:51 +03:00
parent f5d9e11b38
commit ac491fae47
7 changed files with 216 additions and 53 deletions

View File

@ -395,6 +395,9 @@ class Function:
assert any(self.basic_blocks)
return self.basic_blocks[0]
def exits(self):
return [block for block in self.basic_blocks if not any(block.successors())]
def instructions(self):
for basic_block in self.basic_blocks:
yield from iter(basic_block.instructions)
@ -508,7 +511,7 @@ class GetLocal(Instruction):
def opcode(self):
return "getlocal({})".format(self.var_name)
def get_env(self):
def environment(self):
return self.operands[0]
class SetLocal(Instruction):
@ -536,10 +539,10 @@ class SetLocal(Instruction):
def opcode(self):
return "setlocal({})".format(self.var_name)
def get_env(self):
def environment(self):
return self.operands[0]
def get_value(self):
def value(self):
return self.operands[1]
class GetAttr(Instruction):
@ -568,7 +571,7 @@ class GetAttr(Instruction):
def opcode(self):
return "getattr({})".format(repr(self.attr))
def get_env(self):
def env(self):
return self.operands[0]
class SetAttr(Instruction):
@ -597,10 +600,10 @@ class SetAttr(Instruction):
def opcode(self):
return "setattr({})".format(repr(self.attr))
def get_env(self):
def env(self):
return self.operands[0]
def get_value(self):
def value(self):
return self.operands[1]
class GetElem(Instruction):
@ -620,10 +623,10 @@ class GetElem(Instruction):
def opcode(self):
return "getelem"
def get_list(self):
def list(self):
return self.operands[0]
def get_index(self):
def index(self):
return self.operands[1]
class SetElem(Instruction):
@ -646,13 +649,13 @@ class SetElem(Instruction):
def opcode(self):
return "setelem"
def get_list(self):
def list(self):
return self.operands[0]
def get_index(self):
def index(self):
return self.operands[1]
def get_value(self):
def value(self):
return self.operands[2]
class UnaryOp(Instruction):
@ -675,7 +678,7 @@ class UnaryOp(Instruction):
def opcode(self):
return "unaryop({})".format(type(self.op).__name__)
def get_operand(self):
def operand(self):
return self.operands[0]
class Coerce(Instruction):
@ -691,7 +694,7 @@ class Coerce(Instruction):
def opcode(self):
return "coerce"
def get_value(self):
def value(self):
return self.operands[0]
class BinaryOp(Instruction):
@ -717,10 +720,10 @@ class BinaryOp(Instruction):
def opcode(self):
return "binaryop({})".format(type(self.op).__name__)
def get_lhs(self):
def lhs(self):
return self.operands[0]
def get_rhs(self):
def rhs(self):
return self.operands[1]
class Compare(Instruction):
@ -746,10 +749,10 @@ class Compare(Instruction):
def opcode(self):
return "compare({})".format(type(self.op).__name__)
def get_lhs(self):
def lhs(self):
return self.operands[0]
def get_rhs(self):
def rhs(self):
return self.operands[1]
class Builtin(Instruction):
@ -793,7 +796,7 @@ class Closure(Instruction):
def opcode(self):
return "closure({})".format(self.target_function.name)
def get_environment(self):
def environment(self):
return self.operands[0]
class Call(Instruction):
@ -813,10 +816,10 @@ class Call(Instruction):
def opcode(self):
return "call"
def get_function(self):
def function(self):
return self.operands[0]
def get_arguments(self):
def arguments(self):
return self.operands[1:]
class Select(Instruction):
@ -839,13 +842,13 @@ class Select(Instruction):
def opcode(self):
return "select"
def get_condition(self):
def condition(self):
return self.operands[0]
def get_if_true(self):
def if_true(self):
return self.operands[1]
def get_if_false(self):
def if_false(self):
return self.operands[2]
class Branch(Terminator):
@ -863,7 +866,7 @@ class Branch(Terminator):
def opcode(self):
return "branch"
def get_target(self):
def target(self):
return self.operands[0]
class BranchIf(Terminator):
@ -885,13 +888,13 @@ class BranchIf(Terminator):
def opcode(self):
return "branchif"
def get_condition(self):
def condition(self):
return self.operands[0]
def get_if_true(self):
def if_true(self):
return self.operands[1]
def get_if_false(self):
def if_false(self):
return self.operands[2]
class IndirectBranch(Terminator):
@ -911,10 +914,10 @@ class IndirectBranch(Terminator):
def opcode(self):
return "indirectbranch"
def get_target(self):
def target(self):
return self.operands[0]
def get_destinations(self):
def destinations(self):
return self.operands[1:]
def add_destination(self, destination):
@ -939,7 +942,7 @@ class Return(Terminator):
def opcode(self):
return "return"
def get_value(self):
def value(self):
return self.operands[0]
class Unreachable(Terminator):
@ -971,7 +974,7 @@ class Raise(Terminator):
def opcode(self):
return "raise"
def get_value(self):
def value(self):
return self.operands[0]
class Invoke(Terminator):
@ -995,16 +998,16 @@ class Invoke(Terminator):
def opcode(self):
return "invoke"
def get_function(self):
def function(self):
return self.operands[0]
def get_arguments(self):
def arguments(self):
return self.operands[1:-2]
def get_normal_target(self):
def normal_target(self):
return self.operands[-2]
def get_exception_target(self):
def exception_target(self):
return self.operands[-1]
def _operands_as_string(self):

View File

@ -11,27 +11,28 @@ class Module:
if engine is None:
engine = diagnostic.Engine(all_errors_are_fatal=True)
module_name, _ = os.path.splitext(os.path.basename(source_buffer.name))
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine)
inferencer = transforms.Inferencer(engine=engine)
int_monomorphizer = transforms.IntMonomorphizer(engine=engine)
monomorphism_validator = validators.MonomorphismValidator(engine=engine)
escape_validator = validators.EscapeValidator(engine=engine)
ir_generator = transforms.IRGenerator(engine=engine, module_name=module_name)
ir_generator = transforms.IRGenerator(engine=engine, module_name=self.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=engine)
local_access_validator = validators.LocalAccessValidator(engine=engine)
parsetree, comments = parse_buffer(source_buffer, engine=engine)
typedtree = asttyped_rewriter.visit(parsetree)
inferencer.visit(typedtree)
int_monomorphizer.visit(typedtree)
inferencer.visit(typedtree)
monomorphism_validator.visit(typedtree)
escape_validator.visit(typedtree)
ir_generator.visit(typedtree)
self.name = module_name
self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine)
self.typedtree = asttyped_rewriter.visit(self.parsetree)
self.globals = asttyped_rewriter.globals
self.ir = ir_generator.functions
inferencer.visit(self.typedtree)
int_monomorphizer.visit(self.typedtree)
inferencer.visit(self.typedtree)
monomorphism_validator.visit(self.typedtree)
escape_validator.visit(self.typedtree)
self.ir = ir_generator.visit(self.typedtree)
dead_code_eliminator.process(self.ir)
local_access_validator.process(self.ir)
@classmethod
def from_string(cls, source_string, name="input.py", first_line=1, engine=None):

View File

@ -138,7 +138,7 @@ class IRGenerator(algorithm.Visitor):
self.generic_visit(node)
self.terminate(ir.Return(ir.Constant(None, builtins.TNone())))
return func
return self.functions
finally:
self.current_function = old_func
self.current_block = old_block
@ -245,8 +245,18 @@ class IRGenerator(algorithm.Visitor):
self.append(ir.Branch(self.return_target))
def visit_Expr(self, node):
# ignore the value, do it for side effects
self.visit(node.value)
# Ignore the value, do it for side effects.
result = self.visit(node.value)
# See comment in visit_Pass.
if isinstance(result, ir.Constant):
self.visit_Pass(node)
def visit_Pass(self, node):
# Insert a dummy instruction so that analyses which extract
# locations from CFG have something to use.
self.append(ir.BinaryOp(ast.Add(loc=None),
ir.Constant(0, self._size_type), ir.Constant(0, self._size_type)))
def visit_Assign(self, node):
try:
@ -288,9 +298,9 @@ class IRGenerator(algorithm.Visitor):
if any(node.orelse):
if not if_false.is_terminated():
if_false.append(ir.Branch(tail))
head.append(ir.BranchIf(cond, if_true, if_false))
self.append(ir.BranchIf(cond, if_true, if_false), block=head)
else:
head.append(ir.BranchIf(cond, if_true, tail))
self.append(ir.BranchIf(cond, if_true, tail), block=head)
def visit_While(self, node):
try:

View File

@ -1,2 +1,3 @@
from .monomorphism import MonomorphismValidator
from .escape import EscapeValidator
from .local_access import LocalAccessValidator

View File

@ -0,0 +1,121 @@
"""
:class:`LocalAccessValidator` verifies that local variables
are not accessed before being used.
"""
from functools import reduce
from pythonparser import diagnostic
from .. import ir, analyses
class LocalAccessValidator:
def __init__(self, engine):
self.engine = engine
def process(self, functions):
for func in functions:
self.process_function(func)
def process_function(self, func):
# Find all environments allocated in this func.
environments = []
for insn in func.instructions():
if isinstance(insn, ir.Alloc) and ir.is_environment(insn.type):
environments.append(insn)
# Compute initial state of interesting environments.
# Environments consisting only of internal variables (containing a ".")
# are ignored.
initial_state = {}
for env in environments:
env_state = {var: False for var in env.type.params if "." not in var}
if any(env_state):
initial_state[env] = env_state
# Traverse the acyclic graph made of basic blocks and forward edges only,
# while updating the environment state.
dom = analyses.DominatorTree(func)
state = {}
def traverse(block):
# Have we computed the state of this block already?
if block in state:
return state[block]
# No! Which forward edges lead to this block?
# If we dominate a predecessor, it's a back edge instead.
forward_edge_preds = [pred for pred in block.predecessors()
if block not in dom.dominated_by[pred]]
# Figure out what the state is before the leader
# instruction of this block.
pred_states = [traverse(pred) for pred in forward_edge_preds]
block_state = {}
if len(pred_states) > 1:
for env in initial_state:
# The variable has to be initialized in all predecessors
# in order to be initialized in this block.
def merge_state(a, b):
return {var: a[var] and b[var] for var in a}
block_state[env] = reduce(lambda a, b: merge_state(a[env], b[env]),
pred_states)
elif len(pred_states) == 1:
# The state is the same as at the terminator of predecessor.
# We'll mutate it, so copy.
pred_state = pred_states[0]
for env in initial_state:
env_state = pred_state[env]
block_state[env] = {var: env_state[var] for var in env_state}
else:
# This is the entry block.
for env in initial_state:
env_state = initial_state[env]
block_state[env] = {var: env_state[var] for var in env_state}
# Update the state based on block contents, while validating
# that no access to uninitialized variables will be done.
for insn in block.instructions:
if isinstance(insn, (ir.SetLocal, ir.GetLocal)) and \
"." not in insn.var_name:
env, var_name = insn.environment(), insn.var_name
assert env in block_state
assert var_name in block_state[env]
if isinstance(insn, ir.SetLocal):
# We've just initialized it.
block_state[env][var_name] = True
else: # isinstance(insn, ir.GetLocal)
if not block_state[env][var_name]:
# Oops, accessing it uninitialized. Find out where
# the uninitialized state comes from.
pred_at_fault = None
for pred, pred_state in zip(forward_edge_preds, pred_states):
if not pred_state[env][var_name]:
pred_at_fault = pred
assert pred_at_fault is not None
# Report the error.
self._uninitialized_access(insn, pred_at_fault)
# Save the state.
state[block] = block_state
return block_state
for block in func.basic_blocks:
traverse(block)
def _uninitialized_access(self, insn, pred_at_fault):
uninitialized_loc = None
for pred_insn in reversed(pred_at_fault.instructions):
if pred_insn.loc is not None:
uninitialized_loc = pred_insn.loc.begin()
break
assert uninitialized_loc is not None
note = diagnostic.Diagnostic("note",
"variable is not initialized when control flows from this point", {},
uninitialized_loc)
diag = diagnostic.Diagnostic("error",
"variable '{name}' is not always initialized at this point",
{"name": insn.var_name},
insn.loc, notes=[note])
self.engine.process(diag)

View File

@ -0,0 +1,20 @@
# RUN: %python -m artiq.compiler.testbench.module +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
x = 1
if x > 10:
y = 1
# CHECK-L: ${LINE:+1}: error: variable 'y' is not always initialized
x + y
for z in [1]:
pass
# CHECK-L: ${LINE:+1}: error: variable 'z' is not always initialized
-z
if True:
pass
else:
t = 1
# CHECK-L: ${LINE:+1}: error: variable 't' is not always initialized
-t

View File

@ -0,0 +1,7 @@
# RUN: %python -m artiq.compiler.testbench.module %s >%t
if False:
x = 1
else:
x = 2
-x