mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-26 18:38:13 +08:00
Add LocalAccessValidator.
This commit is contained in:
parent
f5d9e11b38
commit
ac491fae47
@ -395,6 +395,9 @@ class Function:
|
||||
assert any(self.basic_blocks)
|
||||
return self.basic_blocks[0]
|
||||
|
||||
def exits(self):
|
||||
return [block for block in self.basic_blocks if not any(block.successors())]
|
||||
|
||||
def instructions(self):
|
||||
for basic_block in self.basic_blocks:
|
||||
yield from iter(basic_block.instructions)
|
||||
@ -508,7 +511,7 @@ class GetLocal(Instruction):
|
||||
def opcode(self):
|
||||
return "getlocal({})".format(self.var_name)
|
||||
|
||||
def get_env(self):
|
||||
def environment(self):
|
||||
return self.operands[0]
|
||||
|
||||
class SetLocal(Instruction):
|
||||
@ -536,10 +539,10 @@ class SetLocal(Instruction):
|
||||
def opcode(self):
|
||||
return "setlocal({})".format(self.var_name)
|
||||
|
||||
def get_env(self):
|
||||
def environment(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[1]
|
||||
|
||||
class GetAttr(Instruction):
|
||||
@ -568,7 +571,7 @@ class GetAttr(Instruction):
|
||||
def opcode(self):
|
||||
return "getattr({})".format(repr(self.attr))
|
||||
|
||||
def get_env(self):
|
||||
def env(self):
|
||||
return self.operands[0]
|
||||
|
||||
class SetAttr(Instruction):
|
||||
@ -597,10 +600,10 @@ class SetAttr(Instruction):
|
||||
def opcode(self):
|
||||
return "setattr({})".format(repr(self.attr))
|
||||
|
||||
def get_env(self):
|
||||
def env(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[1]
|
||||
|
||||
class GetElem(Instruction):
|
||||
@ -620,10 +623,10 @@ class GetElem(Instruction):
|
||||
def opcode(self):
|
||||
return "getelem"
|
||||
|
||||
def get_list(self):
|
||||
def list(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_index(self):
|
||||
def index(self):
|
||||
return self.operands[1]
|
||||
|
||||
class SetElem(Instruction):
|
||||
@ -646,13 +649,13 @@ class SetElem(Instruction):
|
||||
def opcode(self):
|
||||
return "setelem"
|
||||
|
||||
def get_list(self):
|
||||
def list(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_index(self):
|
||||
def index(self):
|
||||
return self.operands[1]
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[2]
|
||||
|
||||
class UnaryOp(Instruction):
|
||||
@ -675,7 +678,7 @@ class UnaryOp(Instruction):
|
||||
def opcode(self):
|
||||
return "unaryop({})".format(type(self.op).__name__)
|
||||
|
||||
def get_operand(self):
|
||||
def operand(self):
|
||||
return self.operands[0]
|
||||
|
||||
class Coerce(Instruction):
|
||||
@ -691,7 +694,7 @@ class Coerce(Instruction):
|
||||
def opcode(self):
|
||||
return "coerce"
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[0]
|
||||
|
||||
class BinaryOp(Instruction):
|
||||
@ -717,10 +720,10 @@ class BinaryOp(Instruction):
|
||||
def opcode(self):
|
||||
return "binaryop({})".format(type(self.op).__name__)
|
||||
|
||||
def get_lhs(self):
|
||||
def lhs(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_rhs(self):
|
||||
def rhs(self):
|
||||
return self.operands[1]
|
||||
|
||||
class Compare(Instruction):
|
||||
@ -746,10 +749,10 @@ class Compare(Instruction):
|
||||
def opcode(self):
|
||||
return "compare({})".format(type(self.op).__name__)
|
||||
|
||||
def get_lhs(self):
|
||||
def lhs(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_rhs(self):
|
||||
def rhs(self):
|
||||
return self.operands[1]
|
||||
|
||||
class Builtin(Instruction):
|
||||
@ -793,7 +796,7 @@ class Closure(Instruction):
|
||||
def opcode(self):
|
||||
return "closure({})".format(self.target_function.name)
|
||||
|
||||
def get_environment(self):
|
||||
def environment(self):
|
||||
return self.operands[0]
|
||||
|
||||
class Call(Instruction):
|
||||
@ -813,10 +816,10 @@ class Call(Instruction):
|
||||
def opcode(self):
|
||||
return "call"
|
||||
|
||||
def get_function(self):
|
||||
def function(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_arguments(self):
|
||||
def arguments(self):
|
||||
return self.operands[1:]
|
||||
|
||||
class Select(Instruction):
|
||||
@ -839,13 +842,13 @@ class Select(Instruction):
|
||||
def opcode(self):
|
||||
return "select"
|
||||
|
||||
def get_condition(self):
|
||||
def condition(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_if_true(self):
|
||||
def if_true(self):
|
||||
return self.operands[1]
|
||||
|
||||
def get_if_false(self):
|
||||
def if_false(self):
|
||||
return self.operands[2]
|
||||
|
||||
class Branch(Terminator):
|
||||
@ -863,7 +866,7 @@ class Branch(Terminator):
|
||||
def opcode(self):
|
||||
return "branch"
|
||||
|
||||
def get_target(self):
|
||||
def target(self):
|
||||
return self.operands[0]
|
||||
|
||||
class BranchIf(Terminator):
|
||||
@ -885,13 +888,13 @@ class BranchIf(Terminator):
|
||||
def opcode(self):
|
||||
return "branchif"
|
||||
|
||||
def get_condition(self):
|
||||
def condition(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_if_true(self):
|
||||
def if_true(self):
|
||||
return self.operands[1]
|
||||
|
||||
def get_if_false(self):
|
||||
def if_false(self):
|
||||
return self.operands[2]
|
||||
|
||||
class IndirectBranch(Terminator):
|
||||
@ -911,10 +914,10 @@ class IndirectBranch(Terminator):
|
||||
def opcode(self):
|
||||
return "indirectbranch"
|
||||
|
||||
def get_target(self):
|
||||
def target(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_destinations(self):
|
||||
def destinations(self):
|
||||
return self.operands[1:]
|
||||
|
||||
def add_destination(self, destination):
|
||||
@ -939,7 +942,7 @@ class Return(Terminator):
|
||||
def opcode(self):
|
||||
return "return"
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[0]
|
||||
|
||||
class Unreachable(Terminator):
|
||||
@ -971,7 +974,7 @@ class Raise(Terminator):
|
||||
def opcode(self):
|
||||
return "raise"
|
||||
|
||||
def get_value(self):
|
||||
def value(self):
|
||||
return self.operands[0]
|
||||
|
||||
class Invoke(Terminator):
|
||||
@ -995,16 +998,16 @@ class Invoke(Terminator):
|
||||
def opcode(self):
|
||||
return "invoke"
|
||||
|
||||
def get_function(self):
|
||||
def function(self):
|
||||
return self.operands[0]
|
||||
|
||||
def get_arguments(self):
|
||||
def arguments(self):
|
||||
return self.operands[1:-2]
|
||||
|
||||
def get_normal_target(self):
|
||||
def normal_target(self):
|
||||
return self.operands[-2]
|
||||
|
||||
def get_exception_target(self):
|
||||
def exception_target(self):
|
||||
return self.operands[-1]
|
||||
|
||||
def _operands_as_string(self):
|
||||
|
@ -11,27 +11,28 @@ class Module:
|
||||
if engine is None:
|
||||
engine = diagnostic.Engine(all_errors_are_fatal=True)
|
||||
|
||||
module_name, _ = os.path.splitext(os.path.basename(source_buffer.name))
|
||||
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
|
||||
|
||||
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine)
|
||||
inferencer = transforms.Inferencer(engine=engine)
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=engine)
|
||||
monomorphism_validator = validators.MonomorphismValidator(engine=engine)
|
||||
escape_validator = validators.EscapeValidator(engine=engine)
|
||||
ir_generator = transforms.IRGenerator(engine=engine, module_name=module_name)
|
||||
ir_generator = transforms.IRGenerator(engine=engine, module_name=self.name)
|
||||
dead_code_eliminator = transforms.DeadCodeEliminator(engine=engine)
|
||||
local_access_validator = validators.LocalAccessValidator(engine=engine)
|
||||
|
||||
parsetree, comments = parse_buffer(source_buffer, engine=engine)
|
||||
typedtree = asttyped_rewriter.visit(parsetree)
|
||||
inferencer.visit(typedtree)
|
||||
int_monomorphizer.visit(typedtree)
|
||||
inferencer.visit(typedtree)
|
||||
monomorphism_validator.visit(typedtree)
|
||||
escape_validator.visit(typedtree)
|
||||
ir_generator.visit(typedtree)
|
||||
|
||||
self.name = module_name
|
||||
self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine)
|
||||
self.typedtree = asttyped_rewriter.visit(self.parsetree)
|
||||
self.globals = asttyped_rewriter.globals
|
||||
self.ir = ir_generator.functions
|
||||
inferencer.visit(self.typedtree)
|
||||
int_monomorphizer.visit(self.typedtree)
|
||||
inferencer.visit(self.typedtree)
|
||||
monomorphism_validator.visit(self.typedtree)
|
||||
escape_validator.visit(self.typedtree)
|
||||
self.ir = ir_generator.visit(self.typedtree)
|
||||
dead_code_eliminator.process(self.ir)
|
||||
local_access_validator.process(self.ir)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, source_string, name="input.py", first_line=1, engine=None):
|
||||
|
@ -138,7 +138,7 @@ class IRGenerator(algorithm.Visitor):
|
||||
self.generic_visit(node)
|
||||
self.terminate(ir.Return(ir.Constant(None, builtins.TNone())))
|
||||
|
||||
return func
|
||||
return self.functions
|
||||
finally:
|
||||
self.current_function = old_func
|
||||
self.current_block = old_block
|
||||
@ -245,8 +245,18 @@ class IRGenerator(algorithm.Visitor):
|
||||
self.append(ir.Branch(self.return_target))
|
||||
|
||||
def visit_Expr(self, node):
|
||||
# ignore the value, do it for side effects
|
||||
self.visit(node.value)
|
||||
# Ignore the value, do it for side effects.
|
||||
result = self.visit(node.value)
|
||||
|
||||
# See comment in visit_Pass.
|
||||
if isinstance(result, ir.Constant):
|
||||
self.visit_Pass(node)
|
||||
|
||||
def visit_Pass(self, node):
|
||||
# Insert a dummy instruction so that analyses which extract
|
||||
# locations from CFG have something to use.
|
||||
self.append(ir.BinaryOp(ast.Add(loc=None),
|
||||
ir.Constant(0, self._size_type), ir.Constant(0, self._size_type)))
|
||||
|
||||
def visit_Assign(self, node):
|
||||
try:
|
||||
@ -288,9 +298,9 @@ class IRGenerator(algorithm.Visitor):
|
||||
if any(node.orelse):
|
||||
if not if_false.is_terminated():
|
||||
if_false.append(ir.Branch(tail))
|
||||
head.append(ir.BranchIf(cond, if_true, if_false))
|
||||
self.append(ir.BranchIf(cond, if_true, if_false), block=head)
|
||||
else:
|
||||
head.append(ir.BranchIf(cond, if_true, tail))
|
||||
self.append(ir.BranchIf(cond, if_true, tail), block=head)
|
||||
|
||||
def visit_While(self, node):
|
||||
try:
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .monomorphism import MonomorphismValidator
|
||||
from .escape import EscapeValidator
|
||||
from .local_access import LocalAccessValidator
|
||||
|
121
artiq/compiler/validators/local_access.py
Normal file
121
artiq/compiler/validators/local_access.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""
|
||||
:class:`LocalAccessValidator` verifies that local variables
|
||||
are not accessed before being used.
|
||||
"""
|
||||
|
||||
from functools import reduce
|
||||
from pythonparser import diagnostic
|
||||
from .. import ir, analyses
|
||||
|
||||
class LocalAccessValidator:
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def process(self, functions):
|
||||
for func in functions:
|
||||
self.process_function(func)
|
||||
|
||||
def process_function(self, func):
|
||||
# Find all environments allocated in this func.
|
||||
environments = []
|
||||
for insn in func.instructions():
|
||||
if isinstance(insn, ir.Alloc) and ir.is_environment(insn.type):
|
||||
environments.append(insn)
|
||||
|
||||
# Compute initial state of interesting environments.
|
||||
# Environments consisting only of internal variables (containing a ".")
|
||||
# are ignored.
|
||||
initial_state = {}
|
||||
for env in environments:
|
||||
env_state = {var: False for var in env.type.params if "." not in var}
|
||||
if any(env_state):
|
||||
initial_state[env] = env_state
|
||||
|
||||
# Traverse the acyclic graph made of basic blocks and forward edges only,
|
||||
# while updating the environment state.
|
||||
dom = analyses.DominatorTree(func)
|
||||
state = {}
|
||||
def traverse(block):
|
||||
# Have we computed the state of this block already?
|
||||
if block in state:
|
||||
return state[block]
|
||||
|
||||
# No! Which forward edges lead to this block?
|
||||
# If we dominate a predecessor, it's a back edge instead.
|
||||
forward_edge_preds = [pred for pred in block.predecessors()
|
||||
if block not in dom.dominated_by[pred]]
|
||||
|
||||
# Figure out what the state is before the leader
|
||||
# instruction of this block.
|
||||
pred_states = [traverse(pred) for pred in forward_edge_preds]
|
||||
block_state = {}
|
||||
if len(pred_states) > 1:
|
||||
for env in initial_state:
|
||||
# The variable has to be initialized in all predecessors
|
||||
# in order to be initialized in this block.
|
||||
def merge_state(a, b):
|
||||
return {var: a[var] and b[var] for var in a}
|
||||
block_state[env] = reduce(lambda a, b: merge_state(a[env], b[env]),
|
||||
pred_states)
|
||||
elif len(pred_states) == 1:
|
||||
# The state is the same as at the terminator of predecessor.
|
||||
# We'll mutate it, so copy.
|
||||
pred_state = pred_states[0]
|
||||
for env in initial_state:
|
||||
env_state = pred_state[env]
|
||||
block_state[env] = {var: env_state[var] for var in env_state}
|
||||
else:
|
||||
# This is the entry block.
|
||||
for env in initial_state:
|
||||
env_state = initial_state[env]
|
||||
block_state[env] = {var: env_state[var] for var in env_state}
|
||||
|
||||
# Update the state based on block contents, while validating
|
||||
# that no access to uninitialized variables will be done.
|
||||
for insn in block.instructions:
|
||||
if isinstance(insn, (ir.SetLocal, ir.GetLocal)) and \
|
||||
"." not in insn.var_name:
|
||||
env, var_name = insn.environment(), insn.var_name
|
||||
assert env in block_state
|
||||
assert var_name in block_state[env]
|
||||
|
||||
if isinstance(insn, ir.SetLocal):
|
||||
# We've just initialized it.
|
||||
block_state[env][var_name] = True
|
||||
else: # isinstance(insn, ir.GetLocal)
|
||||
if not block_state[env][var_name]:
|
||||
# Oops, accessing it uninitialized. Find out where
|
||||
# the uninitialized state comes from.
|
||||
pred_at_fault = None
|
||||
for pred, pred_state in zip(forward_edge_preds, pred_states):
|
||||
if not pred_state[env][var_name]:
|
||||
pred_at_fault = pred
|
||||
assert pred_at_fault is not None
|
||||
|
||||
# Report the error.
|
||||
self._uninitialized_access(insn, pred_at_fault)
|
||||
|
||||
# Save the state.
|
||||
state[block] = block_state
|
||||
|
||||
return block_state
|
||||
|
||||
for block in func.basic_blocks:
|
||||
traverse(block)
|
||||
|
||||
def _uninitialized_access(self, insn, pred_at_fault):
|
||||
uninitialized_loc = None
|
||||
for pred_insn in reversed(pred_at_fault.instructions):
|
||||
if pred_insn.loc is not None:
|
||||
uninitialized_loc = pred_insn.loc.begin()
|
||||
break
|
||||
assert uninitialized_loc is not None
|
||||
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"variable is not initialized when control flows from this point", {},
|
||||
uninitialized_loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"variable '{name}' is not always initialized at this point",
|
||||
{"name": insn.var_name},
|
||||
insn.loc, notes=[note])
|
||||
self.engine.process(diag)
|
20
lit-test/compiler/local_access/invalid.py
Normal file
20
lit-test/compiler/local_access/invalid.py
Normal file
@ -0,0 +1,20 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.module +diag %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
x = 1
|
||||
if x > 10:
|
||||
y = 1
|
||||
# CHECK-L: ${LINE:+1}: error: variable 'y' is not always initialized
|
||||
x + y
|
||||
|
||||
for z in [1]:
|
||||
pass
|
||||
# CHECK-L: ${LINE:+1}: error: variable 'z' is not always initialized
|
||||
-z
|
||||
|
||||
if True:
|
||||
pass
|
||||
else:
|
||||
t = 1
|
||||
# CHECK-L: ${LINE:+1}: error: variable 't' is not always initialized
|
||||
-t
|
7
lit-test/compiler/local_access/valid.py
Normal file
7
lit-test/compiler/local_access/valid.py
Normal file
@ -0,0 +1,7 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.module %s >%t
|
||||
|
||||
if False:
|
||||
x = 1
|
||||
else:
|
||||
x = 2
|
||||
-x
|
Loading…
Reference in New Issue
Block a user