forked from M-Labs/artiq
1
0
Fork 0

py2llvm: remove unnecessary indirection for unary operators

This commit is contained in:
Sebastien Bourdeauducq 2014-09-07 15:09:38 +08:00
parent bce687b4a0
commit c5c38c6376
2 changed files with 12 additions and 38 deletions

View File

@ -54,13 +54,13 @@ class Visitor:
def _visit_expr_UnaryOp(self, node): def _visit_expr_UnaryOp(self, node):
ast_unops = { ast_unops = {
ast.Invert: values.operators.inv, ast.Invert: "o_inv",
ast.Not: values.operators.not_, ast.Not: "o_not",
ast.UAdd: values.operators.pos, ast.UAdd: "o_pos",
ast.USub: values.operators.neg ast.USub: "o_neg"
} }
return ast_unops[type(node.op)](self.visit_expression(node.operand), value = self.visit_expression(node.operand)
self.builder) return getattr(value, ast_unops[type(node.op)])(self.builder)
def _visit_expr_BinOp(self, node): def _visit_expr_BinOp(self, node):
ast_binops = { ast_binops = {
@ -104,17 +104,10 @@ class Visitor:
return r return r
def _visit_expr_Call(self, node): def _visit_expr_Call(self, node):
ast_unfuns = {
"bool": values.operators.bool,
"int": values.operators.int,
"int64": values.operators.int64,
"round": values.operators.round,
"round64": values.operators.round64,
}
fn = node.func.id fn = node.func.id
if fn in ast_unfuns: if fn in {"bool", "int", "int64", "round", "round64"}:
return ast_unfuns[fn](self.visit_expression(node.args[0]), value = self.visit_expression(node.args[0])
self.builder) return getattr(value, "o_"+fn)(self.builder)
elif fn == "Fraction": elif fn == "Fraction":
r = fractions.VFraction() r = fractions.VFraction()
if self.builder is not None: if self.builder is not None:
@ -168,8 +161,7 @@ class Visitor:
else_block = function.append_basic_block("i_else") else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge") merge_block = function.append_basic_block("i_merge")
condition = values.operators.bool(self.visit_expression(node.test), condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder), self.builder.cbranch(condition.get_ssa_value(self.builder),
then_block, else_block) then_block, else_block)
@ -191,16 +183,14 @@ class Visitor:
else_block = function.append_basic_block("w_else") else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge") merge_block = function.append_basic_block("w_merge")
condition = values.operators.bool( condition = self.visit_expression(node.test).o_bool(self.builder)
self.visit_expression(node.test), self.builder)
self.builder.cbranch( self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, else_block) condition.get_ssa_value(self.builder), body_block, else_block)
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
if not is_terminated(self.builder.basic_block): if not is_terminated(self.builder.basic_block):
condition = values.operators.bool( condition = self.visit_expression(node.test).o_bool(self.builder)
self.visit_expression(node.test), self.builder)
self.builder.cbranch( self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block) condition.get_ssa_value(self.builder), body_block, merge_block)

View File

@ -40,18 +40,6 @@ class VGeneric:
return self.o_roundx(64, builder) return self.o_roundx(64, builder)
def _make_unary_operator(op_name):
def op(x, builder):
try:
opf = getattr(x, "o_"+op_name)
except AttributeError:
raise TypeError(
"Unsupported operand type for {}: {}"
.format(op_name, type(x).__name__))
return opf(builder)
return op
def _make_binary_operator(op_name): def _make_binary_operator(op_name):
def op(l, r, builder): def op(l, r, builder):
try: try:
@ -77,10 +65,6 @@ def _make_binary_operator(op_name):
def _make_operators(): def _make_operators():
d = dict() d = dict()
for op_name in ("bool", "int", "int64", "round", "round64",
"inv", "pos", "neg"):
d[op_name] = _make_unary_operator(op_name)
d["not_"] = _make_unary_operator("not")
for op_name in ("add", "sub", "mul", for op_name in ("add", "sub", "mul",
"truediv", "floordiv", "mod", "truediv", "floordiv", "mod",
"pow", "lshift", "rshift", "xor", "pow", "lshift", "rshift", "xor",