From 236d5b886a4510a576803936a878985a36c7f9db Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 22 Jul 2015 02:58:59 +0300 Subject: [PATCH] Add support for Assert. --- artiq/compiler/ir.py | 33 ++++- .../compiler/transforms/artiq_ir_generator.py | 118 ++++++++++++++++-- .../compiler/transforms/asttyped_rewriter.py | 1 - .../transforms/dead_code_eliminator.py | 23 +++- artiq/compiler/transforms/inferencer.py | 11 ++ .../compiler/transforms/llvm_ir_generator.py | 28 ++++- lit-test/compiler/inferencer/error_assert.py | 6 + 7 files changed, 199 insertions(+), 21 deletions(-) create mode 100644 lit-test/compiler/inferencer/error_assert.py diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 55313b7ed..f94df0ed4 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -106,14 +106,13 @@ class User(NamedValue): def __init__(self, operands, typ, name): super().__init__(typ, name) self.operands = [] - if operands is not None: - self.set_operands(operands) + self.set_operands(operands) def set_operands(self, new_operands): - for operand in self.operands: + for operand in set(self.operands): operand.uses.remove(self) self.operands = new_operands - for operand in self.operands: + for operand in set(self.operands): operand.uses.add(self) def drop_references(self): @@ -162,6 +161,9 @@ class Instruction(User): def erase(self): self.remove_from_parent() self.drop_references() + # Check this after drop_references in case this + # is a self-referencing phi. + assert not any(self.uses) def replace_with(self, value): self.replace_all_uses_with(value) @@ -220,7 +222,21 @@ class Phi(Instruction): def add_incoming(self, value, block): assert value.type == self.type self.operands.append(value) + value.uses.add(self) self.operands.append(block) + block.uses.add(self) + + def remove_incoming_value(self, value): + index = self.operands.index(value) + self.operands[index].uses.remove(self) + self.operands[index + 1].uses.remove(self) + del self.operands[index:index + 2] + + def remove_incoming_block(self, block): + index = self.operands.index(block) + self.operands[index - 1].uses.remove(self) + self.operands[index].uses.remove(self) + del self.operands[index - 1:index + 1] def __str__(self): if builtins.is_none(self.type): @@ -268,9 +284,13 @@ class BasicBlock(NamedValue): self.function.remove(self) def erase(self): - for insn in self.instructions: + # self.instructions is updated while iterating + for insn in list(self.instructions): insn.erase() self.remove_from_parent() + # Check this after erasing instructions in case the block + # loops into itself. + assert not any(self.uses) def prepend(self, insn): assert isinstance(insn, Instruction) @@ -817,6 +837,7 @@ class Select(Instruction): """ def __init__(self, cond, if_true, if_false, name=""): assert isinstance(cond, Value) + assert builtins.is_bool(cond.type) assert isinstance(if_true, Value) assert isinstance(if_false, Value) assert if_true.type == if_false.type @@ -864,8 +885,10 @@ class BranchIf(Terminator): """ def __init__(self, cond, if_true, if_false, name=""): assert isinstance(cond, Value) + assert builtins.is_bool(cond.type) assert isinstance(if_true, BasicBlock) assert isinstance(if_false, BasicBlock) + assert if_true != if_false # use Branch instead super().__init__([cond, if_true, if_false], builtins.TNone(), name) def opcode(self): diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 715845b45..c7e633295 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -39,16 +39,22 @@ class ARTIQIRGenerator(algorithm.Visitor): set of variables that will be resolved in global scope :ivar current_block: (:class:`ir.BasicBlock`) basic block to which any new instruction will be appended - :ivar current_env: (:class:`ir.Environment`) + :ivar current_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) the chained function environment, containing variables that can become upvalues - :ivar current_private_env: (:class:`ir.Environment`) + :ivar current_private_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) the private function environment, containing internal state :ivar current_assign: (:class:`ir.Value` or None) the right-hand side of current assignment statement, or a component of a composite right-hand side when visiting a composite left-hand side, such as, in ``x, y = z``, the 2nd tuple element when visting ``y`` + :ivar current_assert_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) + the environment where the individual components of current assert + statement are stored until display + :ivar current_assert_subexprs: (list of (:class:`ast.AST`, string)) + the mapping from components of current assert statement to the names + their values have in :ivar:`current_assert_env` :ivar break_target: (:class:`ir.BasicBlock` or None) the basic block to which ``break`` will transfer control :ivar continue_target: (:class:`ir.BasicBlock` or None) @@ -72,6 +78,8 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_env = None self.current_private_env = None self.current_assign = None + self.current_assert_env = None + self.current_assert_subexprs = None self.break_target = None self.continue_target = None self.return_target = None @@ -203,7 +211,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.SetLocal(env, arg_name, args[index])) for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)): default = self.append(ir.GetLocal(self.current_env, env_default_name)) - value = self.append(ir.Builtin("unwrap", [optargs[index], default], + value = self.append(ir.Builtin("unwrap_or", [optargs[index], default], typ.optargs[arg_name])) self.append(ir.SetLocal(env, arg_name, value)) @@ -736,7 +744,7 @@ class ARTIQIRGenerator(algorithm.Visitor): for index, elt_node in enumerate(node.elts): self.current_assign = \ self.append(ir.GetAttr(old_assign, index, - name="{}.{}".format(old_assign.name, index)), + name="{}.e{}".format(old_assign.name, index)), loc=elt_node.loc) self.visit(elt_node) finally: @@ -805,18 +813,26 @@ class ARTIQIRGenerator(algorithm.Visitor): def visit_BoolOpT(self, node): blocks = [] for value_node in node.values: + value_head = self.current_block value = self.visit(value_node) - blocks.append((value, self.current_block)) + self.instrument_assert(value_node, value) + value_tail = self.current_block + + blocks.append((value, value_head, value_tail)) self.current_block = self.add_block() tail = self.current_block phi = self.append(ir.Phi(node.type)) - for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]): - phi.add_incoming(value, block) - if isinstance(node.op, ast.And): - block.append(ir.BranchIf(value, next_block, tail)) + for ((value, value_head, value_tail), (next_value_head, next_value_tail)) in \ + zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]): + phi.add_incoming(value, value_tail) + if next_value_head != tail: + if isinstance(node.op, ast.And): + value_tail.append(ir.BranchIf(value, next_value_head, tail)) + else: + value_tail.append(ir.BranchIf(value, tail, next_value_head)) else: - block.append(ir.BranchIf(value, tail, next_block)) + value_tail.append(ir.Branch(tail)) return phi def visit_UnaryOpT(self, node): @@ -1005,7 +1021,7 @@ class ARTIQIRGenerator(algorithm.Visitor): ir.Constant(False, builtins.TBool()))) result = self.append(ir.Select(result, on_step, ir.Constant(False, builtins.TBool()))) - elif builtins.isiterable(haystack.type): + elif builtins.is_iterable(haystack.type): length = self.iterable_len(haystack) cmp_result = loop_body2 = None @@ -1068,8 +1084,10 @@ class ARTIQIRGenerator(algorithm.Visitor): # of comparisons. blocks = [] lhs = self.visit(node.left) + self.instrument_assert(node.left, lhs) for op, rhs_node in zip(node.ops, node.comparators): rhs = self.visit(rhs_node) + self.instrument_assert(rhs_node, rhs) result = self.polymorphic_compare_pair(op, lhs, rhs) blocks.append((result, self.current_block)) self.current_block = self.add_block() @@ -1079,7 +1097,10 @@ class ARTIQIRGenerator(algorithm.Visitor): phi = self.append(ir.Phi(node.type)) for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]): phi.add_incoming(value, block) - block.append(ir.BranchIf(value, next_block, tail)) + if next_block != tail: + block.append(ir.BranchIf(value, next_block, tail)) + else: + block.append(ir.Branch(tail)) return phi def visit_builtin_call(self, node): @@ -1211,6 +1232,79 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_block = after_invoke return invoke + def instrument_assert(self, node, value): + if self.current_assert_env is not None: + if isinstance(value, ir.Constant): + return # don't display the values of constants + + if any([algorithm.compare(node, subexpr) + for (subexpr, name) in self.current_assert_subexprs]): + return # don't display the same subexpression twice + + name = self.current_assert_env.type.add("subexpr", ir.TOption(node.type)) + value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)), + loc=node.loc) + self.append(ir.SetLocal(self.current_assert_env, name, value_opt), + loc=node.loc) + self.current_assert_subexprs.append((node, name)) + + def visit_Assert(self, node): + try: + assert_env = self.current_assert_env = \ + self.append(ir.Alloc([], ir.TEnvironment({}), name="assertenv")) + assert_subexprs = self.current_assert_subexprs = [] + init = self.current_block + + prehead = self.current_block = self.add_block() + cond = self.visit(node.test) + head = self.current_block + finally: + self.current_assert_env = None + self.current_assert_subexprs = None + + for subexpr_node, subexpr_name in assert_subexprs: + empty = init.append(ir.Alloc([], ir.TOption(subexpr_node.type))) + init.append(ir.SetLocal(assert_env, subexpr_name, empty)) + init.append(ir.Branch(prehead)) + + if_failed = self.current_block = self.add_block() + + if node.msg: + explanation = node.msg.s + else: + explanation = node.loc.source() + self.append(ir.Builtin("printf", [ + ir.Constant("assertion failed at %s: %s\n", builtins.TStr()), + ir.Constant(str(node.loc.begin()), builtins.TStr()), + ir.Constant(str(explanation), builtins.TStr()), + ], builtins.TNone())) + + for subexpr_node, subexpr_name in assert_subexprs: + subexpr_head = self.current_block + subexpr_value_opt = self.append(ir.GetLocal(assert_env, subexpr_name)) + subexpr_cond = self.append(ir.Builtin("is_some", [subexpr_value_opt], + builtins.TBool())) + + subexpr_body = self.current_block = self.add_block() + self.append(ir.Builtin("printf", [ + ir.Constant(" (%s) = ", builtins.TStr()), + ir.Constant(subexpr_node.loc.source(), builtins.TStr()) + ], builtins.TNone())) + subexpr_value = self.append(ir.Builtin("unwrap", [subexpr_value_opt], + subexpr_node.type)) + self.polymorphic_print([subexpr_value], separator="", suffix="\n") + subexpr_postbody = self.current_block + + subexpr_tail = self.current_block = self.add_block() + self.append(ir.Branch(subexpr_tail), block=subexpr_postbody) + self.append(ir.BranchIf(subexpr_cond, subexpr_body, subexpr_tail), block=subexpr_head) + + self.append(ir.Builtin("abort", [], builtins.TNone())) + self.append(ir.Unreachable()) + + tail = self.current_block = self.add_block() + self.append(ir.BranchIf(cond, tail, if_failed), block=head) + def polymorphic_print(self, values, separator, suffix=""): format_string = "" args = [] diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index a1d0b7bd6..2b8d03f73 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -421,7 +421,6 @@ class ASTTypedRewriter(algorithm.Transformer): visit_YieldFrom = visit_unsupported # stmt - visit_Assert = visit_unsupported visit_ClassDef = visit_unsupported visit_Delete = visit_unsupported visit_Import = visit_unsupported diff --git a/artiq/compiler/transforms/dead_code_eliminator.py b/artiq/compiler/transforms/dead_code_eliminator.py index 92d82b4a7..1d4e08f6f 100644 --- a/artiq/compiler/transforms/dead_code_eliminator.py +++ b/artiq/compiler/transforms/dead_code_eliminator.py @@ -16,4 +16,25 @@ class DeadCodeEliminator: def process_function(self, func): for block in func.basic_blocks: if not any(block.predecessors()) and block != func.entry(): - block.erase() + self.remove_block(block) + + def remove_block(self, block): + # block.uses are updated while iterating + for use in set(block.uses): + if isinstance(use, ir.Phi): + use.remove_incoming_block(block) + if not any(use.operands): + self.remove_instruction(use) + else: + assert False + + block.erase() + + def remove_instruction(self, insn): + for use in set(insn.uses): + if isinstance(use, ir.Phi): + use.remove_incoming_value(insn) + if not any(use.operands): + self.remove_instruction(use) + + insn.erase() diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index a821059a9..21e803e9b 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -947,3 +947,14 @@ class Inferencer(algorithm.Visitor): else: self._unify(self.function.return_type, node.value.type, self.function.name_loc, node.value.loc, makenotes) + + def visit_Assert(self, node): + self.generic_visit(node) + self._unify(node.test.type, builtins.TBool(), + node.test.loc, None) + if node.msg is not None: + if not isinstance(node.msg, asttyped.StrT): + diag = diagnostic.Diagnostic("error", + "assertion message must be a string literal", {}, + node.msg.loc) + self.engine.process(diag) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 80908e10b..e73505d7b 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -95,7 +95,9 @@ class LLVMIRGenerator: if llfun is not None: return llfun - if name in ("llvm.abort", "llvm.donothing"): + if name in "llvm.donothing": + llty = ll.FunctionType(ll.VoidType(), []) + elif name in "llvm.trap": llty = ll.FunctionType(ll.VoidType(), []) elif name == "llvm.round.f64": llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()]) @@ -181,6 +183,18 @@ class LLVMIRGenerator: if ir.is_environment(insn.type): return self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True), name=insn.name) + elif ir.is_option(insn.type): + if len(insn.operands) == 0: # empty + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + return self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), False), 0, + name=insn.name) + elif len(insn.operands) == 1: # full + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + llvalue = self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), True), 0) + return self.llbuilder.insert_value(llvalue, self.map(insn.operands[0]), 1, + name=insn.name) + else: + assert False elif builtins.is_list(insn.type): llsize = self.map(insn.operands[0]) llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) @@ -382,7 +396,17 @@ class LLVMIRGenerator: def process_Builtin(self, insn): if insn.op == "nop": return self.llbuilder.call(self.llbuiltin("llvm.donothing"), []) + if insn.op == "abort": + return self.llbuilder.call(self.llbuiltin("llvm.trap"), []) + elif insn.op == "is_some": + optarg = self.map(insn.operands[0]) + return self.llbuilder.extract_value(optarg, 0, + name=insn.name) elif insn.op == "unwrap": + optarg = self.map(insn.operands[0]) + return self.llbuilder.extract_value(optarg, 1, + name=insn.name) + elif insn.op == "unwrap_or": optarg, default = map(self.map, insn.operands) has_arg = self.llbuilder.extract_value(optarg, 0) arg = self.llbuilder.extract_value(optarg, 1) @@ -455,7 +479,7 @@ class LLVMIRGenerator: def process_Raise(self, insn): # TODO: hack before EH is working - llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [], + llinsn = self.llbuilder.call(self.llbuiltin("llvm.trap"), [], name=insn.name) self.llbuilder.unreachable() return llinsn diff --git a/lit-test/compiler/inferencer/error_assert.py b/lit-test/compiler/inferencer/error_assert.py new file mode 100644 index 000000000..1e7c10284 --- /dev/null +++ b/lit-test/compiler/inferencer/error_assert.py @@ -0,0 +1,6 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = "A" +# CHECK-L: ${LINE:+1}: error: assertion message must be a string literal +assert True, x