diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 1aaf0d97f..58e3566eb 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -395,6 +395,9 @@ class Function: assert any(self.basic_blocks) return self.basic_blocks[0] + def exits(self): + return [block for block in self.basic_blocks if not any(block.successors())] + def instructions(self): for basic_block in self.basic_blocks: yield from iter(basic_block.instructions) @@ -508,7 +511,7 @@ class GetLocal(Instruction): def opcode(self): return "getlocal({})".format(self.var_name) - def get_env(self): + def environment(self): return self.operands[0] class SetLocal(Instruction): @@ -536,10 +539,10 @@ class SetLocal(Instruction): def opcode(self): return "setlocal({})".format(self.var_name) - def get_env(self): + def environment(self): return self.operands[0] - def get_value(self): + def value(self): return self.operands[1] class GetAttr(Instruction): @@ -568,7 +571,7 @@ class GetAttr(Instruction): def opcode(self): return "getattr({})".format(repr(self.attr)) - def get_env(self): + def env(self): return self.operands[0] class SetAttr(Instruction): @@ -597,10 +600,10 @@ class SetAttr(Instruction): def opcode(self): return "setattr({})".format(repr(self.attr)) - def get_env(self): + def env(self): return self.operands[0] - def get_value(self): + def value(self): return self.operands[1] class GetElem(Instruction): @@ -620,10 +623,10 @@ class GetElem(Instruction): def opcode(self): return "getelem" - def get_list(self): + def list(self): return self.operands[0] - def get_index(self): + def index(self): return self.operands[1] class SetElem(Instruction): @@ -646,13 +649,13 @@ class SetElem(Instruction): def opcode(self): return "setelem" - def get_list(self): + def list(self): return self.operands[0] - def get_index(self): + def index(self): return self.operands[1] - def get_value(self): + def value(self): return self.operands[2] class UnaryOp(Instruction): @@ -675,7 +678,7 @@ class UnaryOp(Instruction): def opcode(self): return "unaryop({})".format(type(self.op).__name__) - def get_operand(self): + def operand(self): return self.operands[0] class Coerce(Instruction): @@ -691,7 +694,7 @@ class Coerce(Instruction): def opcode(self): return "coerce" - def get_value(self): + def value(self): return self.operands[0] class BinaryOp(Instruction): @@ -717,10 +720,10 @@ class BinaryOp(Instruction): def opcode(self): return "binaryop({})".format(type(self.op).__name__) - def get_lhs(self): + def lhs(self): return self.operands[0] - def get_rhs(self): + def rhs(self): return self.operands[1] class Compare(Instruction): @@ -746,10 +749,10 @@ class Compare(Instruction): def opcode(self): return "compare({})".format(type(self.op).__name__) - def get_lhs(self): + def lhs(self): return self.operands[0] - def get_rhs(self): + def rhs(self): return self.operands[1] class Builtin(Instruction): @@ -793,7 +796,7 @@ class Closure(Instruction): def opcode(self): return "closure({})".format(self.target_function.name) - def get_environment(self): + def environment(self): return self.operands[0] class Call(Instruction): @@ -813,10 +816,10 @@ class Call(Instruction): def opcode(self): return "call" - def get_function(self): + def function(self): return self.operands[0] - def get_arguments(self): + def arguments(self): return self.operands[1:] class Select(Instruction): @@ -839,13 +842,13 @@ class Select(Instruction): def opcode(self): return "select" - def get_condition(self): + def condition(self): return self.operands[0] - def get_if_true(self): + def if_true(self): return self.operands[1] - def get_if_false(self): + def if_false(self): return self.operands[2] class Branch(Terminator): @@ -863,7 +866,7 @@ class Branch(Terminator): def opcode(self): return "branch" - def get_target(self): + def target(self): return self.operands[0] class BranchIf(Terminator): @@ -885,13 +888,13 @@ class BranchIf(Terminator): def opcode(self): return "branchif" - def get_condition(self): + def condition(self): return self.operands[0] - def get_if_true(self): + def if_true(self): return self.operands[1] - def get_if_false(self): + def if_false(self): return self.operands[2] class IndirectBranch(Terminator): @@ -911,10 +914,10 @@ class IndirectBranch(Terminator): def opcode(self): return "indirectbranch" - def get_target(self): + def target(self): return self.operands[0] - def get_destinations(self): + def destinations(self): return self.operands[1:] def add_destination(self, destination): @@ -939,7 +942,7 @@ class Return(Terminator): def opcode(self): return "return" - def get_value(self): + def value(self): return self.operands[0] class Unreachable(Terminator): @@ -971,7 +974,7 @@ class Raise(Terminator): def opcode(self): return "raise" - def get_value(self): + def value(self): return self.operands[0] class Invoke(Terminator): @@ -995,16 +998,16 @@ class Invoke(Terminator): def opcode(self): return "invoke" - def get_function(self): + def function(self): return self.operands[0] - def get_arguments(self): + def arguments(self): return self.operands[1:-2] - def get_normal_target(self): + def normal_target(self): return self.operands[-2] - def get_exception_target(self): + def exception_target(self): return self.operands[-1] def _operands_as_string(self): diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index 833896d79..22ae58577 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -11,27 +11,28 @@ class Module: if engine is None: engine = diagnostic.Engine(all_errors_are_fatal=True) - module_name, _ = os.path.splitext(os.path.basename(source_buffer.name)) + self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine) inferencer = transforms.Inferencer(engine=engine) int_monomorphizer = transforms.IntMonomorphizer(engine=engine) monomorphism_validator = validators.MonomorphismValidator(engine=engine) escape_validator = validators.EscapeValidator(engine=engine) - ir_generator = transforms.IRGenerator(engine=engine, module_name=module_name) + ir_generator = transforms.IRGenerator(engine=engine, module_name=self.name) + dead_code_eliminator = transforms.DeadCodeEliminator(engine=engine) + local_access_validator = validators.LocalAccessValidator(engine=engine) - parsetree, comments = parse_buffer(source_buffer, engine=engine) - typedtree = asttyped_rewriter.visit(parsetree) - inferencer.visit(typedtree) - int_monomorphizer.visit(typedtree) - inferencer.visit(typedtree) - monomorphism_validator.visit(typedtree) - escape_validator.visit(typedtree) - ir_generator.visit(typedtree) - - self.name = module_name + self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine) + self.typedtree = asttyped_rewriter.visit(self.parsetree) self.globals = asttyped_rewriter.globals - self.ir = ir_generator.functions + inferencer.visit(self.typedtree) + int_monomorphizer.visit(self.typedtree) + inferencer.visit(self.typedtree) + monomorphism_validator.visit(self.typedtree) + escape_validator.visit(self.typedtree) + self.ir = ir_generator.visit(self.typedtree) + dead_code_eliminator.process(self.ir) + local_access_validator.process(self.ir) @classmethod def from_string(cls, source_string, name="input.py", first_line=1, engine=None): diff --git a/artiq/compiler/transforms/ir_generator.py b/artiq/compiler/transforms/ir_generator.py index 000e016cf..c30624cd6 100644 --- a/artiq/compiler/transforms/ir_generator.py +++ b/artiq/compiler/transforms/ir_generator.py @@ -138,7 +138,7 @@ class IRGenerator(algorithm.Visitor): self.generic_visit(node) self.terminate(ir.Return(ir.Constant(None, builtins.TNone()))) - return func + return self.functions finally: self.current_function = old_func self.current_block = old_block @@ -245,8 +245,18 @@ class IRGenerator(algorithm.Visitor): self.append(ir.Branch(self.return_target)) def visit_Expr(self, node): - # ignore the value, do it for side effects - self.visit(node.value) + # Ignore the value, do it for side effects. + result = self.visit(node.value) + + # See comment in visit_Pass. + if isinstance(result, ir.Constant): + self.visit_Pass(node) + + def visit_Pass(self, node): + # Insert a dummy instruction so that analyses which extract + # locations from CFG have something to use. + self.append(ir.BinaryOp(ast.Add(loc=None), + ir.Constant(0, self._size_type), ir.Constant(0, self._size_type))) def visit_Assign(self, node): try: @@ -288,9 +298,9 @@ class IRGenerator(algorithm.Visitor): if any(node.orelse): if not if_false.is_terminated(): if_false.append(ir.Branch(tail)) - head.append(ir.BranchIf(cond, if_true, if_false)) + self.append(ir.BranchIf(cond, if_true, if_false), block=head) else: - head.append(ir.BranchIf(cond, if_true, tail)) + self.append(ir.BranchIf(cond, if_true, tail), block=head) def visit_While(self, node): try: diff --git a/artiq/compiler/validators/__init__.py b/artiq/compiler/validators/__init__.py index a90a89c69..7f0719ea9 100644 --- a/artiq/compiler/validators/__init__.py +++ b/artiq/compiler/validators/__init__.py @@ -1,2 +1,3 @@ from .monomorphism import MonomorphismValidator from .escape import EscapeValidator +from .local_access import LocalAccessValidator diff --git a/artiq/compiler/validators/local_access.py b/artiq/compiler/validators/local_access.py new file mode 100644 index 000000000..4f2b4ffd0 --- /dev/null +++ b/artiq/compiler/validators/local_access.py @@ -0,0 +1,121 @@ +""" +:class:`LocalAccessValidator` verifies that local variables +are not accessed before being used. +""" + +from functools import reduce +from pythonparser import diagnostic +from .. import ir, analyses + +class LocalAccessValidator: + def __init__(self, engine): + self.engine = engine + + def process(self, functions): + for func in functions: + self.process_function(func) + + def process_function(self, func): + # Find all environments allocated in this func. + environments = [] + for insn in func.instructions(): + if isinstance(insn, ir.Alloc) and ir.is_environment(insn.type): + environments.append(insn) + + # Compute initial state of interesting environments. + # Environments consisting only of internal variables (containing a ".") + # are ignored. + initial_state = {} + for env in environments: + env_state = {var: False for var in env.type.params if "." not in var} + if any(env_state): + initial_state[env] = env_state + + # Traverse the acyclic graph made of basic blocks and forward edges only, + # while updating the environment state. + dom = analyses.DominatorTree(func) + state = {} + def traverse(block): + # Have we computed the state of this block already? + if block in state: + return state[block] + + # No! Which forward edges lead to this block? + # If we dominate a predecessor, it's a back edge instead. + forward_edge_preds = [pred for pred in block.predecessors() + if block not in dom.dominated_by[pred]] + + # Figure out what the state is before the leader + # instruction of this block. + pred_states = [traverse(pred) for pred in forward_edge_preds] + block_state = {} + if len(pred_states) > 1: + for env in initial_state: + # The variable has to be initialized in all predecessors + # in order to be initialized in this block. + def merge_state(a, b): + return {var: a[var] and b[var] for var in a} + block_state[env] = reduce(lambda a, b: merge_state(a[env], b[env]), + pred_states) + elif len(pred_states) == 1: + # The state is the same as at the terminator of predecessor. + # We'll mutate it, so copy. + pred_state = pred_states[0] + for env in initial_state: + env_state = pred_state[env] + block_state[env] = {var: env_state[var] for var in env_state} + else: + # This is the entry block. + for env in initial_state: + env_state = initial_state[env] + block_state[env] = {var: env_state[var] for var in env_state} + + # Update the state based on block contents, while validating + # that no access to uninitialized variables will be done. + for insn in block.instructions: + if isinstance(insn, (ir.SetLocal, ir.GetLocal)) and \ + "." not in insn.var_name: + env, var_name = insn.environment(), insn.var_name + assert env in block_state + assert var_name in block_state[env] + + if isinstance(insn, ir.SetLocal): + # We've just initialized it. + block_state[env][var_name] = True + else: # isinstance(insn, ir.GetLocal) + if not block_state[env][var_name]: + # Oops, accessing it uninitialized. Find out where + # the uninitialized state comes from. + pred_at_fault = None + for pred, pred_state in zip(forward_edge_preds, pred_states): + if not pred_state[env][var_name]: + pred_at_fault = pred + assert pred_at_fault is not None + + # Report the error. + self._uninitialized_access(insn, pred_at_fault) + + # Save the state. + state[block] = block_state + + return block_state + + for block in func.basic_blocks: + traverse(block) + + def _uninitialized_access(self, insn, pred_at_fault): + uninitialized_loc = None + for pred_insn in reversed(pred_at_fault.instructions): + if pred_insn.loc is not None: + uninitialized_loc = pred_insn.loc.begin() + break + assert uninitialized_loc is not None + + note = diagnostic.Diagnostic("note", + "variable is not initialized when control flows from this point", {}, + uninitialized_loc) + diag = diagnostic.Diagnostic("error", + "variable '{name}' is not always initialized at this point", + {"name": insn.var_name}, + insn.loc, notes=[note]) + self.engine.process(diag) diff --git a/lit-test/compiler/local_access/invalid.py b/lit-test/compiler/local_access/invalid.py new file mode 100644 index 000000000..94561d2d6 --- /dev/null +++ b/lit-test/compiler/local_access/invalid.py @@ -0,0 +1,20 @@ +# RUN: %python -m artiq.compiler.testbench.module +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = 1 +if x > 10: + y = 1 +# CHECK-L: ${LINE:+1}: error: variable 'y' is not always initialized +x + y + +for z in [1]: + pass +# CHECK-L: ${LINE:+1}: error: variable 'z' is not always initialized +-z + +if True: + pass +else: + t = 1 +# CHECK-L: ${LINE:+1}: error: variable 't' is not always initialized +-t diff --git a/lit-test/compiler/local_access/valid.py b/lit-test/compiler/local_access/valid.py new file mode 100644 index 000000000..0e598ed39 --- /dev/null +++ b/lit-test/compiler/local_access/valid.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.module %s >%t + +if False: + x = 1 +else: + x = 2 +-x