forked from M-Labs/artiq
1
0
Fork 0

Add support for Assert.

This commit is contained in:
whitequark 2015-07-22 02:58:59 +03:00
parent 5d518dcec6
commit 236d5b886a
7 changed files with 199 additions and 21 deletions

View File

@ -106,14 +106,13 @@ class User(NamedValue):
def __init__(self, operands, typ, name):
super().__init__(typ, name)
self.operands = []
if operands is not None:
self.set_operands(operands)
def set_operands(self, new_operands):
for operand in self.operands:
for operand in set(self.operands):
operand.uses.remove(self)
self.operands = new_operands
for operand in self.operands:
for operand in set(self.operands):
operand.uses.add(self)
def drop_references(self):
@ -162,6 +161,9 @@ class Instruction(User):
def erase(self):
self.remove_from_parent()
self.drop_references()
# Check this after drop_references in case this
# is a self-referencing phi.
assert not any(self.uses)
def replace_with(self, value):
self.replace_all_uses_with(value)
@ -220,7 +222,21 @@ class Phi(Instruction):
def add_incoming(self, value, block):
assert value.type == self.type
self.operands.append(value)
value.uses.add(self)
self.operands.append(block)
block.uses.add(self)
def remove_incoming_value(self, value):
index = self.operands.index(value)
self.operands[index].uses.remove(self)
self.operands[index + 1].uses.remove(self)
del self.operands[index:index + 2]
def remove_incoming_block(self, block):
index = self.operands.index(block)
self.operands[index - 1].uses.remove(self)
self.operands[index].uses.remove(self)
del self.operands[index - 1:index + 1]
def __str__(self):
if builtins.is_none(self.type):
@ -268,9 +284,13 @@ class BasicBlock(NamedValue):
self.function.remove(self)
def erase(self):
for insn in self.instructions:
# self.instructions is updated while iterating
for insn in list(self.instructions):
insn.erase()
self.remove_from_parent()
# Check this after erasing instructions in case the block
# loops into itself.
assert not any(self.uses)
def prepend(self, insn):
assert isinstance(insn, Instruction)
@ -817,6 +837,7 @@ class Select(Instruction):
"""
def __init__(self, cond, if_true, if_false, name=""):
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, Value)
assert isinstance(if_false, Value)
assert if_true.type == if_false.type
@ -864,8 +885,10 @@ class BranchIf(Terminator):
"""
def __init__(self, cond, if_true, if_false, name=""):
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock)
assert if_true != if_false # use Branch instead
super().__init__([cond, if_true, if_false], builtins.TNone(), name)
def opcode(self):

View File

@ -39,16 +39,22 @@ class ARTIQIRGenerator(algorithm.Visitor):
set of variables that will be resolved in global scope
:ivar current_block: (:class:`ir.BasicBlock`)
basic block to which any new instruction will be appended
:ivar current_env: (:class:`ir.Environment`)
:ivar current_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the chained function environment, containing variables that
can become upvalues
:ivar current_private_env: (:class:`ir.Environment`)
:ivar current_private_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the private function environment, containing internal state
:ivar current_assign: (:class:`ir.Value` or None)
the right-hand side of current assignment statement, or
a component of a composite right-hand side when visiting
a composite left-hand side, such as, in ``x, y = z``,
the 2nd tuple element when visting ``y``
:ivar current_assert_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the environment where the individual components of current assert
statement are stored until display
:ivar current_assert_subexprs: (list of (:class:`ast.AST`, string))
the mapping from components of current assert statement to the names
their values have in :ivar:`current_assert_env`
:ivar break_target: (:class:`ir.BasicBlock` or None)
the basic block to which ``break`` will transfer control
:ivar continue_target: (:class:`ir.BasicBlock` or None)
@ -72,6 +78,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_env = None
self.current_private_env = None
self.current_assign = None
self.current_assert_env = None
self.current_assert_subexprs = None
self.break_target = None
self.continue_target = None
self.return_target = None
@ -203,7 +211,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetLocal(env, arg_name, args[index]))
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)):
default = self.append(ir.GetLocal(self.current_env, env_default_name))
value = self.append(ir.Builtin("unwrap", [optargs[index], default],
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
typ.optargs[arg_name]))
self.append(ir.SetLocal(env, arg_name, value))
@ -736,7 +744,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
for index, elt_node in enumerate(node.elts):
self.current_assign = \
self.append(ir.GetAttr(old_assign, index,
name="{}.{}".format(old_assign.name, index)),
name="{}.e{}".format(old_assign.name, index)),
loc=elt_node.loc)
self.visit(elt_node)
finally:
@ -805,18 +813,26 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_BoolOpT(self, node):
blocks = []
for value_node in node.values:
value_head = self.current_block
value = self.visit(value_node)
blocks.append((value, self.current_block))
self.instrument_assert(value_node, value)
value_tail = self.current_block
blocks.append((value, value_head, value_tail))
self.current_block = self.add_block()
tail = self.current_block
phi = self.append(ir.Phi(node.type))
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
phi.add_incoming(value, block)
for ((value, value_head, value_tail), (next_value_head, next_value_tail)) in \
zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]):
phi.add_incoming(value, value_tail)
if next_value_head != tail:
if isinstance(node.op, ast.And):
block.append(ir.BranchIf(value, next_block, tail))
value_tail.append(ir.BranchIf(value, next_value_head, tail))
else:
block.append(ir.BranchIf(value, tail, next_block))
value_tail.append(ir.BranchIf(value, tail, next_value_head))
else:
value_tail.append(ir.Branch(tail))
return phi
def visit_UnaryOpT(self, node):
@ -1005,7 +1021,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
ir.Constant(False, builtins.TBool())))
result = self.append(ir.Select(result, on_step,
ir.Constant(False, builtins.TBool())))
elif builtins.isiterable(haystack.type):
elif builtins.is_iterable(haystack.type):
length = self.iterable_len(haystack)
cmp_result = loop_body2 = None
@ -1068,8 +1084,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
# of comparisons.
blocks = []
lhs = self.visit(node.left)
self.instrument_assert(node.left, lhs)
for op, rhs_node in zip(node.ops, node.comparators):
rhs = self.visit(rhs_node)
self.instrument_assert(rhs_node, rhs)
result = self.polymorphic_compare_pair(op, lhs, rhs)
blocks.append((result, self.current_block))
self.current_block = self.add_block()
@ -1079,7 +1097,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
phi = self.append(ir.Phi(node.type))
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
phi.add_incoming(value, block)
if next_block != tail:
block.append(ir.BranchIf(value, next_block, tail))
else:
block.append(ir.Branch(tail))
return phi
def visit_builtin_call(self, node):
@ -1211,6 +1232,79 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_block = after_invoke
return invoke
def instrument_assert(self, node, value):
if self.current_assert_env is not None:
if isinstance(value, ir.Constant):
return # don't display the values of constants
if any([algorithm.compare(node, subexpr)
for (subexpr, name) in self.current_assert_subexprs]):
return # don't display the same subexpression twice
name = self.current_assert_env.type.add("subexpr", ir.TOption(node.type))
value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)),
loc=node.loc)
self.append(ir.SetLocal(self.current_assert_env, name, value_opt),
loc=node.loc)
self.current_assert_subexprs.append((node, name))
def visit_Assert(self, node):
try:
assert_env = self.current_assert_env = \
self.append(ir.Alloc([], ir.TEnvironment({}), name="assertenv"))
assert_subexprs = self.current_assert_subexprs = []
init = self.current_block
prehead = self.current_block = self.add_block()
cond = self.visit(node.test)
head = self.current_block
finally:
self.current_assert_env = None
self.current_assert_subexprs = None
for subexpr_node, subexpr_name in assert_subexprs:
empty = init.append(ir.Alloc([], ir.TOption(subexpr_node.type)))
init.append(ir.SetLocal(assert_env, subexpr_name, empty))
init.append(ir.Branch(prehead))
if_failed = self.current_block = self.add_block()
if node.msg:
explanation = node.msg.s
else:
explanation = node.loc.source()
self.append(ir.Builtin("printf", [
ir.Constant("assertion failed at %s: %s\n", builtins.TStr()),
ir.Constant(str(node.loc.begin()), builtins.TStr()),
ir.Constant(str(explanation), builtins.TStr()),
], builtins.TNone()))
for subexpr_node, subexpr_name in assert_subexprs:
subexpr_head = self.current_block
subexpr_value_opt = self.append(ir.GetLocal(assert_env, subexpr_name))
subexpr_cond = self.append(ir.Builtin("is_some", [subexpr_value_opt],
builtins.TBool()))
subexpr_body = self.current_block = self.add_block()
self.append(ir.Builtin("printf", [
ir.Constant(" (%s) = ", builtins.TStr()),
ir.Constant(subexpr_node.loc.source(), builtins.TStr())
], builtins.TNone()))
subexpr_value = self.append(ir.Builtin("unwrap", [subexpr_value_opt],
subexpr_node.type))
self.polymorphic_print([subexpr_value], separator="", suffix="\n")
subexpr_postbody = self.current_block
subexpr_tail = self.current_block = self.add_block()
self.append(ir.Branch(subexpr_tail), block=subexpr_postbody)
self.append(ir.BranchIf(subexpr_cond, subexpr_body, subexpr_tail), block=subexpr_head)
self.append(ir.Builtin("abort", [], builtins.TNone()))
self.append(ir.Unreachable())
tail = self.current_block = self.add_block()
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
def polymorphic_print(self, values, separator, suffix=""):
format_string = ""
args = []

View File

@ -421,7 +421,6 @@ class ASTTypedRewriter(algorithm.Transformer):
visit_YieldFrom = visit_unsupported
# stmt
visit_Assert = visit_unsupported
visit_ClassDef = visit_unsupported
visit_Delete = visit_unsupported
visit_Import = visit_unsupported

View File

@ -16,4 +16,25 @@ class DeadCodeEliminator:
def process_function(self, func):
for block in func.basic_blocks:
if not any(block.predecessors()) and block != func.entry():
self.remove_block(block)
def remove_block(self, block):
# block.uses are updated while iterating
for use in set(block.uses):
if isinstance(use, ir.Phi):
use.remove_incoming_block(block)
if not any(use.operands):
self.remove_instruction(use)
else:
assert False
block.erase()
def remove_instruction(self, insn):
for use in set(insn.uses):
if isinstance(use, ir.Phi):
use.remove_incoming_value(insn)
if not any(use.operands):
self.remove_instruction(use)
insn.erase()

View File

@ -947,3 +947,14 @@ class Inferencer(algorithm.Visitor):
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, makenotes)
def visit_Assert(self, node):
self.generic_visit(node)
self._unify(node.test.type, builtins.TBool(),
node.test.loc, None)
if node.msg is not None:
if not isinstance(node.msg, asttyped.StrT):
diag = diagnostic.Diagnostic("error",
"assertion message must be a string literal", {},
node.msg.loc)
self.engine.process(diag)

View File

@ -95,7 +95,9 @@ class LLVMIRGenerator:
if llfun is not None:
return llfun
if name in ("llvm.abort", "llvm.donothing"):
if name in "llvm.donothing":
llty = ll.FunctionType(ll.VoidType(), [])
elif name in "llvm.trap":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "llvm.round.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
@ -181,6 +183,18 @@ class LLVMIRGenerator:
if ir.is_environment(insn.type):
return self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True),
name=insn.name)
elif ir.is_option(insn.type):
if len(insn.operands) == 0: # empty
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
return self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), False), 0,
name=insn.name)
elif len(insn.operands) == 1: # full
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llvalue = self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), True), 0)
return self.llbuilder.insert_value(llvalue, self.map(insn.operands[0]), 1,
name=insn.name)
else:
assert False
elif builtins.is_list(insn.type):
llsize = self.map(insn.operands[0])
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
@ -382,7 +396,17 @@ class LLVMIRGenerator:
def process_Builtin(self, insn):
if insn.op == "nop":
return self.llbuilder.call(self.llbuiltin("llvm.donothing"), [])
if insn.op == "abort":
return self.llbuilder.call(self.llbuiltin("llvm.trap"), [])
elif insn.op == "is_some":
optarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 0,
name=insn.name)
elif insn.op == "unwrap":
optarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 1,
name=insn.name)
elif insn.op == "unwrap_or":
optarg, default = map(self.map, insn.operands)
has_arg = self.llbuilder.extract_value(optarg, 0)
arg = self.llbuilder.extract_value(optarg, 1)
@ -455,7 +479,7 @@ class LLVMIRGenerator:
def process_Raise(self, insn):
# TODO: hack before EH is working
llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [],
llinsn = self.llbuilder.call(self.llbuiltin("llvm.trap"), [],
name=insn.name)
self.llbuilder.unreachable()
return llinsn

View File

@ -0,0 +1,6 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
x = "A"
# CHECK-L: ${LINE:+1}: error: assertion message must be a string literal
assert True, x