2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-25 19:28:26 +08:00

Remove parts of py2llvm that are implemented in the new compiler.

This commit is contained in:
whitequark 2015-08-10 20:36:39 +03:00
parent 62e6f8a03d
commit 200330a808
23 changed files with 169 additions and 2746 deletions

View File

@ -1,6 +0,0 @@
from artiq.py2llvm.module import Module
def get_runtime_binary(runtime, func_def):
module = Module(runtime)
module.compile_function(func_def, dict())
return module.emit_object()

View File

@ -1,539 +0,0 @@
from pythonparser import ast
import llvmlite_or1k.ir as ll
from artiq.py2llvm import values, base_types, fractions, lists, iterators
from artiq.py2llvm.tools import is_terminated
_ast_unops = {
ast.Invert: "o_inv",
ast.Not: "o_not",
ast.UAdd: "o_pos",
ast.USub: "o_neg"
}
_ast_binops = {
ast.Add: values.operators.add,
ast.Sub: values.operators.sub,
ast.Mult: values.operators.mul,
ast.Div: values.operators.truediv,
ast.FloorDiv: values.operators.floordiv,
ast.Mod: values.operators.mod,
ast.Pow: values.operators.pow,
ast.LShift: values.operators.lshift,
ast.RShift: values.operators.rshift,
ast.BitOr: values.operators.or_,
ast.BitXor: values.operators.xor,
ast.BitAnd: values.operators.and_
}
_ast_cmps = {
ast.Eq: values.operators.eq,
ast.NotEq: values.operators.ne,
ast.Lt: values.operators.lt,
ast.LtE: values.operators.le,
ast.Gt: values.operators.gt,
ast.GtE: values.operators.ge
}
class Visitor:
def __init__(self, runtime, ns, builder=None):
self.runtime = runtime
self.ns = ns
self.builder = builder
self._break_stack = []
self._continue_stack = []
self._active_exception_stack = []
self._exception_level_stack = [0]
# builder can be None for visit_expression
def visit_expression(self, node):
method = "_visit_expr_" + node.__class__.__name__
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in expression"
.format(node.__class__.__name__))
return visitor(node)
def _visit_expr_Name(self, node):
try:
r = self.ns[node.id]
except KeyError:
raise NameError("Name '{}' is not defined".format(node.id))
return r
def _visit_expr_NameConstant(self, node):
v = node.value
if v is None:
r = base_types.VNone()
elif isinstance(v, bool):
r = base_types.VBool()
else:
raise NotImplementedError
if self.builder is not None:
r.set_const_value(self.builder, v)
return r
def _visit_expr_Num(self, node):
n = node.n
if isinstance(n, int):
if abs(n) < 2**31:
r = base_types.VInt()
else:
r = base_types.VInt(64)
elif isinstance(n, float):
r = base_types.VFloat()
else:
raise NotImplementedError
if self.builder is not None:
r.set_const_value(self.builder, n)
return r
def _visit_expr_UnaryOp(self, node):
value = self.visit_expression(node.operand)
return getattr(value, _ast_unops[type(node.op)])(self.builder)
def _visit_expr_BinOp(self, node):
return _ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right),
self.builder)
def _visit_expr_BoolOp(self, node):
if self.builder is not None:
initial_block = self.builder.basic_block
function = initial_block.function
merge_block = function.append_basic_block("b_merge")
test_blocks = []
test_values = []
for i, value in enumerate(node.values):
if self.builder is not None:
test_block = function.append_basic_block("b_{}_test".format(i))
test_blocks.append(test_block)
self.builder.position_at_end(test_block)
test_values.append(self.visit_expression(value))
result = test_values[0].new()
for value in test_values[1:]:
result.merge(value)
if self.builder is not None:
self.builder.position_at_end(initial_block)
result.alloca(self.builder, "b_result")
self.builder.branch(test_blocks[0])
next_test_blocks = test_blocks[1:]
next_test_blocks.append(None)
for block, next_block, value in zip(test_blocks,
next_test_blocks,
test_values):
self.builder.position_at_end(block)
bval = value.o_bool(self.builder)
result.auto_store(self.builder,
value.auto_load(self.builder))
if next_block is None:
self.builder.branch(merge_block)
else:
if isinstance(node.op, ast.Or):
self.builder.cbranch(bval.auto_load(self.builder),
merge_block,
next_block)
elif isinstance(node.op, ast.And):
self.builder.cbranch(bval.auto_load(self.builder),
next_block,
merge_block)
else:
raise NotImplementedError
self.builder.position_at_end(merge_block)
return result
def _visit_expr_Compare(self, node):
comparisons = []
old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a)
comparison = _ast_cmps[type(op)](old_comparator, comparator,
self.builder)
comparisons.append(comparison)
old_comparator = comparator
r = comparisons[0]
for comparison in comparisons[1:]:
r = values.operators.and_(r, comparison)
return r
def _visit_expr_Call(self, node):
fn = node.func.id
if fn in {"bool", "int", "int64", "round", "round64", "float", "len"}:
value = self.visit_expression(node.args[0])
return getattr(value, "o_" + fn)(self.builder)
elif fn == "Fraction":
r = fractions.VFraction()
if self.builder is not None:
numerator = self.visit_expression(node.args[0])
denominator = self.visit_expression(node.args[1])
r.set_value_nd(self.builder, numerator, denominator)
return r
elif fn == "range":
return iterators.IRange(
self.builder,
[self.visit_expression(arg) for arg in node.args])
elif fn == "syscall":
return self.runtime.build_syscall(
node.args[0].s,
[self.visit_expression(expr) for expr in node.args[1:]],
self.builder)
else:
raise NameError("Function '{}' is not defined".format(fn))
def _visit_expr_Attribute(self, node):
value = self.visit_expression(node.value)
return value.o_getattr(node.attr, self.builder)
def _visit_expr_List(self, node):
elts = [self.visit_expression(elt) for elt in node.elts]
if elts:
el_type = elts[0].new()
for elt in elts[1:]:
el_type.merge(elt)
else:
el_type = base_types.VNone()
count = len(elts)
r = lists.VList(el_type, count)
r.elts = elts
return r
def _visit_expr_ListComp(self, node):
if len(node.generators) != 1:
raise NotImplementedError
generator = node.generators[0]
if not isinstance(generator, ast.comprehension):
raise NotImplementedError
if not isinstance(generator.target, ast.Name):
raise NotImplementedError
target = generator.target.id
if not isinstance(generator.iter, ast.Call):
raise NotImplementedError
if not isinstance(generator.iter.func, ast.Name):
raise NotImplementedError
if generator.iter.func.id != "range":
raise NotImplementedError
if len(generator.iter.args) != 1:
raise NotImplementedError
if not isinstance(generator.iter.args[0], ast.Num):
raise NotImplementedError
count = generator.iter.args[0].n
# Prevent incorrect use of the generator target, if it is defined in
# the local function namespace.
if target in self.ns:
old_target_val = self.ns[target]
del self.ns[target]
else:
old_target_val = None
elt = self.visit_expression(node.elt)
if old_target_val is not None:
self.ns[target] = old_target_val
el_type = elt.new()
r = lists.VList(el_type, count)
r.elt = elt
return r
def _visit_expr_Subscript(self, node):
value = self.visit_expression(node.value)
if isinstance(node.slice, ast.Index):
index = self.visit_expression(node.slice.value)
else:
raise NotImplementedError
return value.o_subscript(index, self.builder)
def visit_statements(self, stmts):
for node in stmts:
node_type = node.__class__.__name__
method = "_visit_stmt_" + node_type
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement"
.format(node_type))
visitor(node)
if node_type in ("Return", "Break", "Continue"):
break
def _bb_terminated(self):
return is_terminated(self.builder.basic_block)
def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value)
if isinstance(node.value, ast.List):
if len(node.targets) > 1:
raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
for i, elt in enumerate(val.elts):
idx = base_types.VInt()
idx.set_const_value(self.builder, i)
target.o_subscript(idx, self.builder).set_value(self.builder,
elt)
elif isinstance(node.value, ast.ListComp):
if len(node.targets) > 1:
raise NotImplementedError
target = self.visit_expression(node.targets[0])
target.set_count(self.builder, val.alloc_count)
i = base_types.VInt()
i.alloca(self.builder)
i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0))
function = self.builder.basic_block.function
copy_block = function.append_basic_block("ai_copy")
end_block = function.append_basic_block("ai_end")
self.builder.branch(copy_block)
self.builder.position_at_end(copy_block)
target.o_subscript(i, self.builder).set_value(self.builder,
val.elt)
i.auto_store(self.builder, self.builder.add(
i.auto_load(self.builder),
ll.Constant(ll.IntType(32), 1)))
cont = self.builder.icmp_signed(
"<", i.auto_load(self.builder),
ll.Constant(ll.IntType(32), val.alloc_count))
self.builder.cbranch(cont, copy_block, end_block)
self.builder.position_at_end(end_block)
else:
for target in node.targets:
target = self.visit_expression(target)
target.set_value(self.builder, val)
def _visit_stmt_AugAssign(self, node):
target = self.visit_expression(node.target)
right = self.visit_expression(node.value)
val = _ast_binops[type(node.op)](target, right, self.builder)
target.set_value(self.builder, val)
def _visit_stmt_Expr(self, node):
self.visit_expression(node.value)
def _visit_stmt_If(self, node):
function = self.builder.basic_block.function
then_block = function.append_basic_block("i_then")
else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge")
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(condition.auto_load(self.builder),
then_block, else_block)
self.builder.position_at_end(then_block)
self.visit_statements(node.body)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _enter_loop_body(self, break_block, continue_block):
self._break_stack.append(break_block)
self._continue_stack.append(continue_block)
self._exception_level_stack.append(0)
def _leave_loop_body(self):
self._exception_level_stack.pop()
self._continue_stack.pop()
self._break_stack.pop()
def _visit_stmt_While(self, node):
function = self.builder.basic_block.function
body_block = function.append_basic_block("w_body")
else_block = function.append_basic_block("w_else")
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.auto_load(self.builder), body_block, else_block)
continue_block = function.append_basic_block("w_continue")
merge_block = function.append_basic_block("w_merge")
self.builder.position_at_end(body_block)
self._enter_loop_body(merge_block, continue_block)
self.visit_statements(node.body)
self._leave_loop_body()
if not self._bb_terminated():
self.builder.branch(continue_block)
self.builder.position_at_end(continue_block)
condition = self.visit_expression(node.test).o_bool(self.builder)
self.builder.cbranch(
condition.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _visit_stmt_For(self, node):
function = self.builder.basic_block.function
it = self.visit_expression(node.iter)
target = self.visit_expression(node.target)
itval = it.get_value_ptr()
body_block = function.append_basic_block("f_body")
else_block = function.append_basic_block("f_else")
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, else_block)
continue_block = function.append_basic_block("f_continue")
merge_block = function.append_basic_block("f_merge")
self.builder.position_at_end(body_block)
target.set_value(self.builder, itval)
self._enter_loop_body(merge_block, continue_block)
self.visit_statements(node.body)
self._leave_loop_body()
if not self._bb_terminated():
self.builder.branch(continue_block)
self.builder.position_at_end(continue_block)
cont = it.o_next(self.builder)
self.builder.cbranch(
cont.auto_load(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _break_loop_body(self, target_block):
exception_levels = self._exception_level_stack[-1]
if exception_levels:
self.runtime.build_pop(self.builder, exception_levels)
self.builder.branch(target_block)
def _visit_stmt_Break(self, node):
self._break_loop_body(self._break_stack[-1])
def _visit_stmt_Continue(self, node):
self._break_loop_body(self._continue_stack[-1])
def _visit_stmt_Return(self, node):
if node.value is None:
val = base_types.VNone()
else:
val = self.visit_expression(node.value)
exception_levels = sum(self._exception_level_stack)
if exception_levels:
self.runtime.build_pop(self.builder, exception_levels)
if isinstance(val, base_types.VNone):
self.builder.ret_void()
else:
self.builder.ret(val.auto_load(self.builder))
def _visit_stmt_Pass(self, node):
pass
def _visit_stmt_Raise(self, node):
if self._active_exception_stack:
finally_block, propagate, propagate_eid = (
self._active_exception_stack[-1])
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
if node.exc is not None:
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
self.builder.store(eid, propagate_eid)
self.builder.branch(finally_block)
else:
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
self.runtime.build_raise(self.builder, eid)
def _handle_exception(self, function, finally_block,
propagate, propagate_eid, handlers):
eid = self.runtime.build_getid(self.builder)
self._active_exception_stack.append(
(finally_block, propagate, propagate_eid))
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
self.builder.store(eid, propagate_eid)
for handler in handlers:
handled_exc_block = function.append_basic_block("try_exc_h")
cont_exc_block = function.append_basic_block("try_exc_c")
if handler.type is None:
self.builder.branch(handled_exc_block)
else:
if isinstance(handler.type, ast.Tuple):
match = self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32),
handler.type.elts[0].args[0].n))
for elt in handler.type.elts[1:]:
match = self.builder.or_(
match,
self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32), elt.args[0].n)))
else:
match = self.builder.icmp_signed(
"==", eid,
ll.Constant(ll.IntType(32), handler.type.args[0].n))
self.builder.cbranch(match, handled_exc_block, cont_exc_block)
self.builder.position_at_end(handled_exc_block)
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
self.visit_statements(handler.body)
if not self._bb_terminated():
self.builder.branch(finally_block)
self.builder.position_at_end(cont_exc_block)
self.builder.branch(finally_block)
self._active_exception_stack.pop()
def _visit_stmt_Try(self, node):
function = self.builder.basic_block.function
noexc_block = function.append_basic_block("try_noexc")
exc_block = function.append_basic_block("try_exc")
finally_block = function.append_basic_block("try_finally")
propagate = self.builder.alloca(ll.IntType(1),
name="propagate")
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
propagate_eid = self.builder.alloca(ll.IntType(32),
name="propagate_eid")
exception_occured = self.runtime.build_catch(self.builder)
self.builder.cbranch(exception_occured, exc_block, noexc_block)
self.builder.position_at_end(noexc_block)
self._exception_level_stack[-1] += 1
self.visit_statements(node.body)
self._exception_level_stack[-1] -= 1
if not self._bb_terminated():
self.runtime.build_pop(self.builder, 1)
self.visit_statements(node.orelse)
if not self._bb_terminated():
self.builder.branch(finally_block)
self.builder.position_at_end(exc_block)
self._handle_exception(function, finally_block,
propagate, propagate_eid, node.handlers)
propagate_block = function.append_basic_block("try_propagate")
merge_block = function.append_basic_block("try_merge")
self.builder.position_at_end(finally_block)
self.visit_statements(node.finalbody)
if not self._bb_terminated():
self.builder.cbranch(
self.builder.load(propagate),
propagate_block, merge_block)
self.builder.position_at_end(propagate_block)
self.runtime.build_raise(self.builder, self.builder.load(propagate_eid))
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)

View File

@ -1,321 +0,0 @@
import llvmlite_or1k.ir as ll
from artiq.py2llvm.values import VGeneric
class VNone(VGeneric):
def get_llvm_type(self):
return ll.VoidType()
def alloca(self, builder, name):
pass
def set_const_value(self, builder, v):
assert v is None
def set_value(self, builder, other):
if not isinstance(other, VNone):
raise TypeError
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
def o_not(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, True)
return r
class VInt(VGeneric):
def __init__(self, nbits=32):
VGeneric.__init__(self)
self.nbits = nbits
def get_llvm_type(self):
return ll.IntType(self.nbits)
def __repr__(self):
return "<VInt:{}>".format(self.nbits)
def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits:
self.nbits = other.nbits
else:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def set_value(self, builder, n):
self.auto_store(
builder, n.o_intx(self.nbits, builder).auto_load(builder))
def set_const_value(self, builder, n):
self.auto_store(builder, ll.Constant(self.get_llvm_type(), n))
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.auto_store(
builder, builder.icmp_signed(
"==" if inv else "!=",
self.auto_load(builder),
ll.Constant(self.get_llvm_type(), 0)))
return r
def o_float(self, builder):
r = VFloat()
if builder is not None:
if isinstance(self, VBool):
cf = builder.uitofp
else:
cf = builder.sitofp
r.auto_store(builder, cf(self.auto_load(builder),
r.get_llvm_type()))
return r
def o_not(self, builder):
return self.o_bool(builder, inv=True)
def o_neg(self, builder):
r = VInt(self.nbits)
if builder is not None:
r.auto_store(
builder, builder.mul(
self.auto_load(builder),
ll.Constant(self.get_llvm_type(), -1)))
return r
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.auto_store(
builder, self.auto_load(builder))
if self.nbits > target_bits:
r.auto_store(
builder, builder.trunc(self.auto_load(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
if isinstance(self, VBool):
ef = builder.zext
else:
ef = builder.sext
r.auto_store(
builder, ef(self.auto_load(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def o_truediv(self, other, builder):
if isinstance(other, VInt):
left = self.o_float(builder)
right = other.o_float(builder)
return left.o_truediv(right, builder)
else:
return NotImplemented
def _make_vint_binop_method(builder_name, bool_op):
def binop_method(self, other, builder):
if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits)
if not bool_op and target_bits == 1:
target_bits = 32
if bool_op and target_bits == 1:
r = VBool()
else:
r = VInt(target_bits)
if builder is not None:
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name)
r.auto_store(
builder, bf(left.auto_load(builder),
right.auto_load(builder)))
return r
else:
return NotImplemented
return binop_method
for _method_name, _builder_name, _bool_op in (("o_add", "add", False),
("o_sub", "sub", False),
("o_mul", "mul", False),
("o_floordiv", "sdiv", False),
("o_mod", "srem", False),
("o_and", "and_", True),
("o_xor", "xor", True),
("o_or", "or_", True)):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name, _bool_op))
def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder):
if isinstance(other, VInt):
r = VBool()
if builder is not None:
target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder)
r.auto_store(
builder,
builder.icmp_signed(
icmp_val, left.auto_load(builder),
right.auto_load(builder)))
return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", "=="),
("o_ne", "!="),
("o_lt", "<"),
("o_le", "<="),
("o_gt", ">"),
("o_ge", ">=")):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
class VBool(VInt):
def __init__(self):
VInt.__init__(self, 1)
__repr__ = VGeneric.__repr__
same_type = VGeneric.same_type
merge = VGeneric.merge
def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b))
class VFloat(VGeneric):
def get_llvm_type(self):
return ll.DoubleType()
def set_value(self, builder, v):
if not isinstance(v, VFloat):
raise TypeError
self.auto_store(builder, v.auto_load(builder))
def set_const_value(self, builder, n):
self.auto_store(builder, ll.Constant(self.get_llvm_type(), n))
def o_float(self, builder):
r = VFloat()
if builder is not None:
r.auto_store(builder, self.auto_load(builder))
return r
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.auto_store(
builder, builder.fcmp_ordered(
"==" if inv else "!=",
self.auto_load(builder),
ll.Constant(self.get_llvm_type(), 0.0)))
return r
def o_not(self, builder):
return self.o_bool(builder, True)
def o_neg(self, builder):
r = VFloat()
if builder is not None:
r.auto_store(
builder, builder.fmul(
self.auto_load(builder),
ll.Constant(self.get_llvm_type(), -1.0)))
return r
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
r.auto_store(builder, builder.fptosi(self.auto_load(builder),
r.get_llvm_type()))
return r
def o_roundx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
function = builder.basic_block.function
neg_block = function.append_basic_block("fr_neg")
merge_block = function.append_basic_block("fr_merge")
half = VFloat()
half.alloca(builder, "half")
half.set_const_value(builder, 0.5)
condition = builder.fcmp_ordered(
"<",
self.auto_load(builder),
ll.Constant(self.get_llvm_type(), 0.0))
builder.cbranch(condition, neg_block, merge_block)
builder.position_at_end(neg_block)
half.set_const_value(builder, -0.5)
builder.branch(merge_block)
builder.position_at_end(merge_block)
s = builder.fadd(self.auto_load(builder), half.auto_load(builder))
r.auto_store(builder, builder.fptosi(s, r.get_llvm_type()))
return r
def o_floordiv(self, other, builder):
return self.o_truediv(other, builder).o_int64(builder).o_float(builder)
def _make_vfloat_binop_method(builder_name, reverse):
def binop_method(self, other, builder):
if not hasattr(other, "o_float"):
return NotImplemented
r = VFloat()
if builder is not None:
left = self.o_float(builder)
right = other.o_float(builder)
if reverse:
left, right = right, left
bf = getattr(builder, builder_name)
r.auto_store(
builder, bf(left.auto_load(builder),
right.auto_load(builder)))
return r
return binop_method
for _method_name, _builder_name in (("add", "fadd"),
("sub", "fsub"),
("mul", "fmul"),
("truediv", "fdiv")):
setattr(VFloat, "o_" + _method_name,
_make_vfloat_binop_method(_builder_name, False))
setattr(VFloat, "or_" + _method_name,
_make_vfloat_binop_method(_builder_name, True))
def _make_vfloat_cmp_method(fcmp_val):
def cmp_method(self, other, builder):
if not hasattr(other, "o_float"):
return NotImplemented
r = VBool()
if builder is not None:
left = self.o_float(builder)
right = other.o_float(builder)
r.auto_store(
builder,
builder.fcmp_ordered(
fcmp_val, left.auto_load(builder),
right.auto_load(builder)))
return r
return cmp_method
for _method_name, _fcmp_val in (("o_eq", "=="),
("o_ne", "!="),
("o_lt", "<"),
("o_le", "<="),
("o_gt", ">"),
("o_ge", ">=")):
setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val))

View File

@ -1,75 +0,0 @@
import pythonparser.algorithm
from pythonparser import ast
from copy import deepcopy
from artiq.py2llvm.ast_body import Visitor
from artiq.py2llvm import base_types
class _TypeScanner(pythonparser.algorithm.Visitor):
def __init__(self, env, ns):
self.exprv = Visitor(env, ns)
def _update_target(self, target, val):
ns = self.exprv.ns
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = deepcopy(val)
elif isinstance(target, ast.Subscript):
target = target.value
levels = 0
while isinstance(target, ast.Subscript):
target = target.value
levels += 1
if isinstance(target, ast.Name):
target_value = ns[target.id]
for i in range(levels):
target_value = target_value.o_subscript(None, None)
target_value.merge_subscript(val)
else:
raise NotImplementedError
else:
raise NotImplementedError
def visit_Assign(self, node):
val = self.exprv.visit_expression(node.value)
for target in node.targets:
self._update_target(target, val)
def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(
op=node.op, left=node.target, right=node.value))
self._update_target(node.target, val)
def visit_For(self, node):
it = self.exprv.visit_expression(node.iter)
self._update_target(node.target, it.get_value_ptr())
self.generic_visit(node)
def visit_Return(self, node):
if node.value is None:
val = base_types.VNone()
else:
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns
if "return" in ns:
ns["return"].merge(val)
else:
ns["return"] = deepcopy(val)
def infer_function_types(env, node, param_types):
ns = deepcopy(param_types)
ts = _TypeScanner(env, ns)
ts.visit(node)
while True:
prev_ns = deepcopy(ns)
ts = _TypeScanner(env, ns)
ts.visit(node)
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed
if "return" not in ns:
ns["return"] = base_types.VNone()
return ns

View File

@ -1,51 +0,0 @@
from artiq.py2llvm.values import operators
from artiq.py2llvm.base_types import VInt
class IRange:
def __init__(self, builder, args):
minimum, step = None, None
if len(args) == 1:
maximum = args[0]
elif len(args) == 2:
minimum, maximum = args
else:
minimum, maximum, step = args
if minimum is None:
minimum = VInt()
if builder is not None:
minimum.set_const_value(builder, 0)
if step is None:
step = VInt()
if builder is not None:
step.set_const_value(builder, 1)
self._counter = minimum.new()
self._counter.merge(maximum)
self._counter.merge(step)
self._minimum = self._counter.new()
self._maximum = self._counter.new()
self._step = self._counter.new()
if builder is not None:
self._minimum.alloca(builder, "irange_min")
self._maximum.alloca(builder, "irange_max")
self._step.alloca(builder, "irange_step")
self._counter.alloca(builder, "irange_count")
self._minimum.set_value(builder, minimum)
self._maximum.set_value(builder, maximum)
self._step.set_value(builder, step)
counter_init = operators.sub(self._minimum, self._step, builder)
self._counter.set_value(builder, counter_init)
# must be a pointer value that can be dereferenced anytime
# to get the current value of the iterator
def get_value_ptr(self):
return self._counter
def o_next(self, builder):
self._counter.set_value(
builder,
operators.add(self._counter, self._step, builder))
return operators.lt(self._counter, self._maximum, builder)

View File

@ -1,72 +0,0 @@
import llvmlite_or1k.ir as ll
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VInt, VNone
class VList(VGeneric):
def __init__(self, el_type, alloc_count):
VGeneric.__init__(self)
self.el_type = el_type
self.alloc_count = alloc_count
def get_llvm_type(self):
count = 0 if self.alloc_count is None else self.alloc_count
if isinstance(self.el_type, VNone):
return ll.LiteralStructType([ll.IntType(32)])
else:
return ll.LiteralStructType([
ll.IntType(32), ll.ArrayType(self.el_type.get_llvm_type(),
count)])
def __repr__(self):
return "<VList:{} x{}>".format(
repr(self.el_type),
"?" if self.alloc_count is None else self.alloc_count)
def same_type(self, other):
return (isinstance(other, VList)
and self.el_type.same_type(other.el_type))
def merge(self, other):
if isinstance(other, VList):
if self.alloc_count:
if other.alloc_count:
self.el_type.merge(other.el_type)
if self.alloc_count < other.alloc_count:
self.alloc_count = other.alloc_count
else:
self.el_type = other.el_type.new()
self.alloc_count = other.alloc_count
else:
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def merge_subscript(self, other):
self.el_type.merge(other)
def set_count(self, builder, count):
count_ptr = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0),
ll.Constant(ll.IntType(32), 0)])
builder.store(ll.Constant(ll.IntType(32), count), count_ptr)
def o_len(self, builder):
r = VInt()
if builder is not None:
count_ptr = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0),
ll.Constant(ll.IntType(32), 0)])
r.auto_store(builder, builder.load(count_ptr))
return r
def o_subscript(self, index, builder):
r = self.el_type.new()
if builder is not None and not isinstance(r, VNone):
index = index.o_int(builder).auto_load(builder)
ssa_r = builder.gep(self.llvm_value, [
ll.Constant(ll.IntType(32), 0),
ll.Constant(ll.IntType(32), 1),
index])
r.auto_store(builder, ssa_r)
return r

View File

@ -1,62 +0,0 @@
import llvmlite_or1k.ir as ll
import llvmlite_or1k.binding as llvm
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
class Module:
def __init__(self, runtime=None):
self.llvm_module = ll.Module("main")
self.runtime = runtime
if self.runtime is not None:
self.runtime.init_module(self)
fractions.init_module(self)
def finalize(self):
self.llvm_module_ref = llvm.parse_assembly(str(self.llvm_module))
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pm = llvm.create_module_pass_manager()
pmb.populate(pm)
pm.run(self.llvm_module_ref)
def get_ee(self):
self.finalize()
tm = llvm.Target.from_default_triple().create_target_machine()
ee = llvm.create_mcjit_compiler(self.llvm_module_ref, tm)
ee.finalize_object()
return ee
def emit_object(self):
self.finalize()
return self.runtime.emit_object()
def compile_function(self, func_def, param_types):
ns = infer_types.infer_function_types(self.runtime, func_def, param_types)
retval = ns["return"]
function_type = ll.FunctionType(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in func_def.args.args])
function = ll.Function(self.llvm_module, function_type, func_def.name)
bb = function.append_basic_block("entry")
builder = ll.IRBuilder()
builder.position_at_end(bb)
for arg_ast, arg_llvm in zip(func_def.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(func_def.args.args, function.args):
ns[arg_ast.arg].auto_store(builder, arg_llvm)
visitor = ast_body.Visitor(self.runtime, ns, builder)
visitor.visit_statements(func_def.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, base_types.VNone):
builder.ret_void()
else:
builder.ret(retval.auto_load(builder))
return function, retval

View File

@ -0,0 +1,169 @@
import unittest
from pythonparser import parse, ast
import inspect
from fractions import Fraction
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
import struct
import llvmlite_or1k.binding as llvm
from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types, lists
from artiq.py2llvm.module import Module
def simplify_encode(a, b):
f = Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_encode(op, a, b, c, d):
if op == 0:
f = Fraction(a, b) - Fraction(c, d)
elif op == 1:
f = Fraction(a, b) + Fraction(c, d)
elif op == 2:
f = Fraction(a, b) * Fraction(c, d)
else:
f = Fraction(a, b) / Fraction(c, d)
return f.numerator*1000 + f.denominator
def frac_arith_encode_int(op, a, b, x):
if op == 0:
f = Fraction(a, b) - x
elif op == 1:
f = Fraction(a, b) + x
elif op == 2:
f = Fraction(a, b) * x
else:
f = Fraction(a, b) / x
return f.numerator*1000 + f.denominator
def frac_arith_encode_int_rev(op, a, b, x):
if op == 0:
f = x - Fraction(a, b)
elif op == 1:
f = x + Fraction(a, b)
elif op == 2:
f = x * Fraction(a, b)
else:
f = x / Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_float(op, a, b, x):
if op == 0:
return Fraction(a, b) - x
elif op == 1:
return Fraction(a, b) + x
elif op == 2:
return Fraction(a, b) * x
else:
return Fraction(a, b) / x
def frac_arith_float_rev(op, a, b, x):
if op == 0:
return x - Fraction(a, b)
elif op == 1:
return x + Fraction(a, b)
elif op == 2:
return x * Fraction(a, b)
else:
return x / Fraction(a, b)
class CodeGenCase(unittest.TestCase):
def test_frac_simplify(self):
simplify_encode_c = CompiledFunction(
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
for a in _test_range():
for b in _test_range():
self.assertEqual(
simplify_encode_c(a, b), simplify_encode(a, b))
def _test_frac_arith(self, op):
frac_arith_encode_c = CompiledFunction(
frac_arith_encode, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"c": base_types.VInt(), "d": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for c in _test_range():
for d in _test_range():
self.assertEqual(
frac_arith_encode_c(op, a, b, c, d),
frac_arith_encode(op, a, b, c, d))
def test_frac_add(self):
self._test_frac_arith(0)
def test_frac_sub(self):
self._test_frac_arith(1)
def test_frac_mul(self):
self._test_frac_arith(2)
def test_frac_div(self):
self._test_frac_arith(3)
def _test_frac_arith_int(self, op, rev):
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertEqual(
f_c(op, a, b, x),
f(op, a, b, x))
def test_frac_add_int(self):
self._test_frac_arith_int(0, False)
self._test_frac_arith_int(0, True)
def test_frac_sub_int(self):
self._test_frac_arith_int(1, False)
self._test_frac_arith_int(1, True)
def test_frac_mul_int(self):
self._test_frac_arith_int(2, False)
self._test_frac_arith_int(2, True)
def test_frac_div_int(self):
self._test_frac_arith_int(3, False)
self._test_frac_arith_int(3, True)
def _test_frac_arith_float(self, op, rev):
f = frac_arith_float_rev if rev else frac_arith_float
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertAlmostEqual(
f_c(op, a, b, x/2),
f(op, a, b, x/2))
def test_frac_add_float(self):
self._test_frac_arith_float(0, False)
self._test_frac_arith_float(0, True)
def test_frac_sub_float(self):
self._test_frac_arith_float(1, False)
self._test_frac_arith_float(1, True)
def test_frac_mul_float(self):
self._test_frac_arith_float(2, False)
self._test_frac_arith_float(2, True)
def test_frac_div_float(self):
self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True)

View File

@ -1,5 +0,0 @@
import llvmlite_or1k.ir as ll
def is_terminated(basic_block):
return (basic_block.instructions
and isinstance(basic_block.instructions[-1], ll.Terminator))

View File

@ -1,94 +0,0 @@
from types import SimpleNamespace
from copy import copy
import llvmlite_or1k.ir as ll
class VGeneric:
def __init__(self):
self.llvm_value = None
def new(self):
r = copy(self)
r.llvm_value = None
return r
def __repr__(self):
return "<" + self.__class__.__name__ + ">"
def same_type(self, other):
return isinstance(other, self.__class__)
def merge(self, other):
if not self.same_type(other):
raise TypeError("Incompatible types: {} and {}"
.format(repr(self), repr(other)))
def auto_load(self, builder):
if isinstance(self.llvm_value.type, ll.PointerType):
return builder.load(self.llvm_value)
else:
return self.llvm_value
def auto_store(self, builder, llvm_value):
if self.llvm_value is None:
self.llvm_value = llvm_value
elif isinstance(self.llvm_value.type, ll.PointerType):
builder.store(llvm_value, self.llvm_value)
else:
raise RuntimeError(
"Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name=""):
if self.llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
def o_int(self, builder):
return self.o_intx(32, builder)
def o_int64(self, builder):
return self.o_intx(64, builder)
def o_round(self, builder):
return self.o_roundx(32, builder)
def o_round64(self, builder):
return self.o_roundx(64, builder)
def _make_binary_operator(op_name):
def op(l, r, builder):
try:
opf = getattr(l, "o_" + op_name)
except AttributeError:
result = NotImplemented
else:
result = opf(r, builder)
if result is NotImplemented:
try:
ropf = getattr(r, "or_" + op_name)
except AttributeError:
result = NotImplemented
else:
result = ropf(l, builder)
if result is NotImplemented:
raise TypeError(
"Unsupported operand types for {}: {} and {}"
.format(op_name, type(l).__name__, type(r).__name__))
return result
return op
def _make_operators():
d = dict()
for op_name in ("add", "sub", "mul",
"truediv", "floordiv", "mod",
"pow", "lshift", "rshift", "xor",
"eq", "ne", "lt", "le", "gt", "ge"):
d[op_name] = _make_binary_operator(op_name)
d["and_"] = _make_binary_operator("and")
d["or_"] = _make_binary_operator("or")
return SimpleNamespace(**d)
operators = _make_operators()

View File

@ -1,372 +0,0 @@
import unittest
from pythonparser import parse, ast
import inspect
from fractions import Fraction
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
import struct
import llvmlite_or1k.binding as llvm
from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types, lists
from artiq.py2llvm.module import Module
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
if struct.calcsize("P") < 8:
from ctypes import _dlopen, RTLD_GLOBAL
_dlopen("libgcc_s.so", RTLD_GLOBAL)
def _base_types(choice):
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 and 5 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True | True or False
bar = None
myf = 4.5
myf2 = myf + x
if choice and foo and not bar:
return d
elif myf2:
return x + c
else:
return int64(8)
def _build_function_types(f):
return infer_function_types(
None, parse(inspect.getsource(f)),
dict())
class FunctionBaseTypesCase(unittest.TestCase):
def setUp(self):
self.ns = _build_function_types(_base_types)
def test_simple_types(self):
self.assertIsInstance(self.ns["foo"], base_types.VBool)
self.assertIsInstance(self.ns["bar"], base_types.VNone)
self.assertIsInstance(self.ns["d"], base_types.VInt)
self.assertEqual(self.ns["d"].nbits, 32)
self.assertIsInstance(self.ns["x"], base_types.VInt)
self.assertEqual(self.ns["x"].nbits, 64)
self.assertIsInstance(self.ns["myf"], base_types.VFloat)
self.assertIsInstance(self.ns["myf2"], base_types.VFloat)
def test_promotion(self):
for v in "abc":
self.assertIsInstance(self.ns[v], base_types.VInt)
self.assertEqual(self.ns[v].nbits, 64)
def test_return(self):
self.assertIsInstance(self.ns["return"], base_types.VInt)
self.assertEqual(self.ns["return"].nbits, 64)
def test_list_types():
a = [0, 0, 0, 0, 0]
for i in range(2):
a[i] = int64(8)
return a
class FunctionListTypesCase(unittest.TestCase):
def setUp(self):
self.ns = _build_function_types(test_list_types)
def test_list_types(self):
self.assertIsInstance(self.ns["a"], lists.VList)
self.assertIsInstance(self.ns["a"].el_type, base_types.VInt)
self.assertEqual(self.ns["a"].el_type.nbits, 64)
self.assertEqual(self.ns["a"].alloc_count, 5)
self.assertIsInstance(self.ns["i"], base_types.VInt)
self.assertEqual(self.ns["i"].nbits, 32)
def _value_to_ctype(v):
if isinstance(v, base_types.VBool):
return c_int
elif isinstance(v, base_types.VInt):
if v.nbits == 32:
return c_int32
elif v.nbits == 64:
return c_int64
else:
raise NotImplementedError(str(v))
elif isinstance(v, base_types.VFloat):
return c_double
else:
raise NotImplementedError(str(v))
class CompiledFunction:
def __init__(self, function, param_types):
module = Module()
func_def = parse(inspect.getsource(function)).body[0]
function, retval = module.compile_function(func_def, param_types)
argvals = [param_types[arg.arg] for arg in func_def.args.args]
ee = module.get_ee()
cfptr = ee.get_pointer_to_global(
module.llvm_module_ref.get_function(function.name))
retval_ctype = _value_to_ctype(retval)
argval_ctypes = [_value_to_ctype(argval) for argval in argvals]
self.cfunc = CFUNCTYPE(retval_ctype, *argval_ctypes)(cfptr)
# HACK: prevent garbage collection of self.cfunc internals
self.ee = ee
def __call__(self, *args):
return self.cfunc(*args)
def arith(op, a, b):
if op == 0:
return a + b
elif op == 1:
return a - b
elif op == 2:
return a * b
else:
return a / b
def is_prime(x):
d = 2
while d*d <= x:
if not x % d:
return False
d += 1
return True
def simplify_encode(a, b):
f = Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_encode(op, a, b, c, d):
if op == 0:
f = Fraction(a, b) - Fraction(c, d)
elif op == 1:
f = Fraction(a, b) + Fraction(c, d)
elif op == 2:
f = Fraction(a, b) * Fraction(c, d)
else:
f = Fraction(a, b) / Fraction(c, d)
return f.numerator*1000 + f.denominator
def frac_arith_encode_int(op, a, b, x):
if op == 0:
f = Fraction(a, b) - x
elif op == 1:
f = Fraction(a, b) + x
elif op == 2:
f = Fraction(a, b) * x
else:
f = Fraction(a, b) / x
return f.numerator*1000 + f.denominator
def frac_arith_encode_int_rev(op, a, b, x):
if op == 0:
f = x - Fraction(a, b)
elif op == 1:
f = x + Fraction(a, b)
elif op == 2:
f = x * Fraction(a, b)
else:
f = x / Fraction(a, b)
return f.numerator*1000 + f.denominator
def frac_arith_float(op, a, b, x):
if op == 0:
return Fraction(a, b) - x
elif op == 1:
return Fraction(a, b) + x
elif op == 2:
return Fraction(a, b) * x
else:
return Fraction(a, b) / x
def frac_arith_float_rev(op, a, b, x):
if op == 0:
return x - Fraction(a, b)
elif op == 1:
return x + Fraction(a, b)
elif op == 2:
return x * Fraction(a, b)
else:
return x / Fraction(a, b)
def list_test():
x = 80
a = [3 for x in range(7)]
b = [1, 2, 4, 5, 4, 0, 5]
a[3] = x
a[0] += 6
a[1] = b[1] + b[2]
acc = 0
for i in range(7):
if i and a[i]:
acc += 1
acc += a[i]
return acc
def corner_cases():
two = True + True - (not True)
three = two + True//True - False*True
two_float = three - True/True
one_float = two_float - (1.0 == bool(0.1))
zero = int(one_float) + round(-0.6)
eleven_float = zero + 5.5//0.5
ten_float = eleven_float + round(Fraction(2, -3))
return ten_float
def _test_range():
for i in range(5, 10):
yield i
yield -i
class CodeGenCase(unittest.TestCase):
def _test_float_arith(self, op):
arith_c = CompiledFunction(arith, {
"op": base_types.VInt(),
"a": base_types.VFloat(), "b": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2))
def test_float_add(self):
self._test_float_arith(0)
def test_float_sub(self):
self._test_float_arith(1)
def test_float_mul(self):
self._test_float_arith(2)
def test_float_div(self):
self._test_float_arith(3)
def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
for i in range(200):
self.assertEqual(is_prime_c(i), is_prime(i))
def test_frac_simplify(self):
simplify_encode_c = CompiledFunction(
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
for a in _test_range():
for b in _test_range():
self.assertEqual(
simplify_encode_c(a, b), simplify_encode(a, b))
def _test_frac_arith(self, op):
frac_arith_encode_c = CompiledFunction(
frac_arith_encode, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"c": base_types.VInt(), "d": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for c in _test_range():
for d in _test_range():
self.assertEqual(
frac_arith_encode_c(op, a, b, c, d),
frac_arith_encode(op, a, b, c, d))
def test_frac_add(self):
self._test_frac_arith(0)
def test_frac_sub(self):
self._test_frac_arith(1)
def test_frac_mul(self):
self._test_frac_arith(2)
def test_frac_div(self):
self._test_frac_arith(3)
def _test_frac_arith_int(self, op, rev):
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VInt()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertEqual(
f_c(op, a, b, x),
f(op, a, b, x))
def test_frac_add_int(self):
self._test_frac_arith_int(0, False)
self._test_frac_arith_int(0, True)
def test_frac_sub_int(self):
self._test_frac_arith_int(1, False)
self._test_frac_arith_int(1, True)
def test_frac_mul_int(self):
self._test_frac_arith_int(2, False)
self._test_frac_arith_int(2, True)
def test_frac_div_int(self):
self._test_frac_arith_int(3, False)
self._test_frac_arith_int(3, True)
def _test_frac_arith_float(self, op, rev):
f = frac_arith_float_rev if rev else frac_arith_float
f_c = CompiledFunction(f, {
"op": base_types.VInt(),
"a": base_types.VInt(), "b": base_types.VInt(),
"x": base_types.VFloat()})
for a in _test_range():
for b in _test_range():
for x in _test_range():
self.assertAlmostEqual(
f_c(op, a, b, x/2),
f(op, a, b, x/2))
def test_frac_add_float(self):
self._test_frac_arith_float(0, False)
self._test_frac_arith_float(0, True)
def test_frac_sub_float(self):
self._test_frac_arith_float(1, False)
self._test_frac_arith_float(1, True)
def test_frac_mul_float(self):
self._test_frac_arith_float(2, False)
self._test_frac_arith_float(2, True)
def test_frac_div_float(self):
self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True)
def test_list(self):
list_test_c = CompiledFunction(list_test, dict())
self.assertEqual(list_test_c(), list_test())
def test_corner_cases(self):
corner_cases_c = CompiledFunction(corner_cases, dict())
self.assertEqual(corner_cases_c(), corner_cases())

View File

@ -1,44 +0,0 @@
import unittest
import ast
from artiq import ns
from artiq.coredevice import comm_dummy, core
from artiq.transforms.unparse import unparse
optimize_in = """
def run():
dds_sysclk = Fraction(1000000000, 1)
n = seconds_to_mu((1.2345 * Fraction(1, 1000000000)))
with sequential:
frequency = 345 * Fraction(1000000, 1)
frequency_to_ftw_return = int((((2 ** 32) * frequency) / dds_sysclk))
ftw = frequency_to_ftw_return
with sequential:
ftw2 = ftw
ftw_to_frequency_return = ((ftw2 * dds_sysclk) / (2 ** 32))
f = ftw_to_frequency_return
phi = ((1000 * mu_to_seconds(n)) * f)
do_something(int(phi))
"""
optimize_out = """
def run():
now = syscall('now_init')
try:
do_something(344)
finally:
syscall('now_save', now)
"""
class OptimizeCase(unittest.TestCase):
def test_optimize(self):
dmgr = dict()
dmgr["comm"] = comm_dummy.Comm(dmgr)
coredev = core.Core(dmgr, ref_period=1*ns)
func_def = ast.parse(optimize_in).body[0]
coredev.transform_stack(func_def, dict(), dict())
self.assertEqual(unparse(func_def), optimize_out)

View File

@ -1,156 +0,0 @@
import ast
import operator
from fractions import Fraction
from artiq.transforms.tools import *
from artiq.language.core import int64, round64
_ast_unops = {
ast.Invert: operator.inv,
ast.Not: operator.not_,
ast.UAdd: operator.pos,
ast.USub: operator.neg
}
_ast_binops = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.BitOr: operator.or_,
ast.BitXor: operator.xor,
ast.BitAnd: operator.and_
}
_ast_cmpops = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge
}
_ast_boolops = {
ast.Or: lambda x, y: x or y,
ast.And: lambda x, y: x and y
}
class _ConstantFolder(ast.NodeTransformer):
def visit_UnaryOp(self, node):
self.generic_visit(node)
try:
operand = eval_constant(node.operand)
except NotConstant:
return node
try:
op = _ast_unops[type(node.op)]
except KeyError:
return node
try:
result = value_to_ast(op(operand))
except:
return node
return ast.copy_location(result, node)
def visit_BinOp(self, node):
self.generic_visit(node)
try:
left, right = eval_constant(node.left), eval_constant(node.right)
except NotConstant:
return node
try:
op = _ast_binops[type(node.op)]
except KeyError:
return node
try:
result = value_to_ast(op(left, right))
except:
return node
return ast.copy_location(result, node)
def visit_Compare(self, node):
self.generic_visit(node)
try:
operands = [eval_constant(node.left)]
except NotConstant:
operands = [node.left]
ops = []
for op, right_ast in zip(node.ops, node.comparators):
try:
right = eval_constant(right_ast)
except NotConstant:
right = right_ast
if (not isinstance(operands[-1], ast.AST)
and not isinstance(right, ast.AST)):
left = operands.pop()
operands.append(_ast_cmpops[type(op)](left, right))
else:
ops.append(op)
operands.append(right_ast)
operands = [operand if isinstance(operand, ast.AST)
else ast.copy_location(value_to_ast(operand), node)
for operand in operands]
if len(operands) == 1:
return operands[0]
else:
node.left = operands[0]
node.right = operands[1:]
node.ops = ops
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
new_values = []
for value in node.values:
try:
value_c = eval_constant(value)
except NotConstant:
new_values.append(value)
else:
if new_values and not isinstance(new_values[-1], ast.AST):
op = _ast_boolops[type(node.op)]
new_values[-1] = op(new_values[-1], value_c)
else:
new_values.append(value_c)
new_values = [v if isinstance(v, ast.AST) else value_to_ast(v)
for v in new_values]
if len(new_values) > 1:
node.values = new_values
return node
else:
return new_values[0]
def visit_Call(self, node):
self.generic_visit(node)
fn = node.func.id
constant_ops = {
"int": int,
"int64": int64,
"round": round,
"round64": round64,
"Fraction": Fraction
}
if fn in constant_ops:
args = []
for arg in node.args:
try:
args.append(eval_constant(arg))
except NotConstant:
return node
result = value_to_ast(constant_ops[fn](*args))
return ast.copy_location(result, node)
else:
return node
def fold_constants(node):
_ConstantFolder().visit(node)

View File

@ -1,59 +0,0 @@
import ast
from artiq.transforms.tools import is_ref_transparent
class _SourceLister(ast.NodeVisitor):
def __init__(self):
self.sources = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
self.sources.add(node.id)
class _DeadCodeRemover(ast.NodeTransformer):
def __init__(self, kept_targets):
self.kept_targets = kept_targets
def visit_Assign(self, node):
new_targets = []
for target in node.targets:
if (not isinstance(target, ast.Name)
or target.id in self.kept_targets):
new_targets.append(target)
if not new_targets and is_ref_transparent(node.value)[0]:
return None
else:
return node
def visit_AugAssign(self, node):
if (isinstance(node.target, ast.Name)
and node.target.id not in self.kept_targets
and is_ref_transparent(node.value)[0]):
return None
else:
return node
def visit_If(self, node):
self.generic_visit(node)
if isinstance(node.test, ast.NameConstant):
if node.test.value:
return node.body
else:
return node.orelse
else:
return node
def visit_While(self, node):
self.generic_visit(node)
if isinstance(node.test, ast.NameConstant) and not node.test.value:
return node.orelse
else:
return node
def remove_dead_code(func_def):
sl = _SourceLister()
sl.visit(func_def)
_DeadCodeRemover(sl.sources).visit(func_def)

View File

@ -1,149 +0,0 @@
import ast
from copy import copy, deepcopy
from collections import defaultdict
from artiq.transforms.tools import is_ref_transparent, count_all_nodes
class _TargetLister(ast.NodeVisitor):
def __init__(self):
self.targets = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self.targets.add(node.id)
class _InterAssignRemover(ast.NodeTransformer):
def __init__(self):
self.replacements = dict()
self.modified_names = set()
# name -> set of names that depend on it
# i.e. when x is modified, dependencies[x] is the set of names that
# cannot be replaced anymore
self.dependencies = defaultdict(set)
def invalidate(self, name):
try:
del self.replacements[name]
except KeyError:
pass
for d in self.dependencies[name]:
self.invalidate(d)
del self.dependencies[name]
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
try:
return deepcopy(self.replacements[node.id])
except KeyError:
return node
else:
self.modified_names.add(node.id)
self.invalidate(node.id)
return node
def visit_Assign(self, node):
node.value = self.visit(node.value)
node.targets = [self.visit(target) for target in node.targets]
rt, depends_on = is_ref_transparent(node.value)
if rt and count_all_nodes(node.value) < 100:
for target in node.targets:
if isinstance(target, ast.Name):
if target.id not in depends_on:
self.replacements[target.id] = node.value
for d in depends_on:
self.dependencies[d].add(target.id)
return node
def visit_AugAssign(self, node):
left = deepcopy(node.target)
left.ctx = ast.Load()
newnode = ast.copy_location(
ast.Assign(
targets=[node.target],
value=ast.BinOp(left=left, op=node.op, right=node.value)
),
node
)
return self.visit_Assign(newnode)
def modified_names_push(self):
prev_modified_names = self.modified_names
self.modified_names = set()
return prev_modified_names
def modified_names_pop(self, prev_modified_names):
for name in self.modified_names:
self.invalidate(name)
self.modified_names |= prev_modified_names
def visit_Try(self, node):
prev_modified_names = self.modified_names_push()
node.body = [self.visit(stmt) for stmt in node.body]
self.modified_names_pop(prev_modified_names)
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
for handler in node.handlers:
self.replacements = copy(prev_replacements)
handler.body = [self.visit(stmt) for stmt in handler.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(stmt) for stmt in node.orelse]
self.modified_names_pop(prev_modified_names)
prev_modified_names = self.modified_names_push()
node.finalbody = [self.visit(stmt) for stmt in node.finalbody]
self.modified_names_pop(prev_modified_names)
return node
def visit_If(self, node):
node.test = self.visit(node.test)
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
return node
def visit_loop(self, node):
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
tl = _TargetLister()
for n in node.body:
tl.visit(n)
for name in tl.targets:
self.invalidate(name)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
def visit_For(self, node):
prev_modified_names = self.modified_names_push()
node.target = self.visit(node.target)
self.modified_names_pop(prev_modified_names)
node.iter = self.visit(node.iter)
self.visit_loop(node)
return node
def visit_While(self, node):
self.visit_loop(node)
node.test = self.visit(node.test)
return node
def remove_inter_assigns(func_def):
_InterAssignRemover().visit(func_def)

View File

@ -1,141 +0,0 @@
import ast
from fractions import Fraction
from artiq.language import core as core_language
from artiq.language import units
embeddable_funcs = (
core_language.delay_mu, core_language.at_mu, core_language.now_mu,
core_language.delay,
core_language.seconds_to_mu, core_language.mu_to_seconds,
core_language.syscall, core_language.watchdog,
range, bool, int, float, round, len,
core_language.int64, core_language.round64,
Fraction, core_language.EncodedException
)
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr)
ast.fix_missing_locations(expr)
code = compile(expr, "<ast>", "eval")
return eval(code, symdict)
class NotASTRepresentable(Exception):
pass
def value_to_ast(value):
if isinstance(value, core_language.int64): # must be before int
return ast.Call(
func=ast.Name("int64", ast.Load()),
args=[ast.Num(int(value))],
keywords=[], starargs=None, kwargs=None)
elif isinstance(value, bool) or value is None:
# must also be before int
# isinstance(True/False, int) == True
return ast.NameConstant(value)
elif isinstance(value, (int, float)):
return ast.Num(value)
elif isinstance(value, Fraction):
return ast.Call(
func=ast.Name("Fraction", ast.Load()),
args=[ast.Num(value.numerator), ast.Num(value.denominator)],
keywords=[], starargs=None, kwargs=None)
elif isinstance(value, str):
return ast.Str(value)
elif isinstance(value, list):
elts = [value_to_ast(elt) for elt in value]
return ast.List(elts, ast.Load())
else:
for kg in core_language.kernel_globals:
if value is getattr(core_language, kg):
return ast.Name(kg, ast.Load())
raise NotASTRepresentable(str(value))
class NotConstant(Exception):
pass
def eval_constant(node):
if isinstance(node, ast.Num):
return node.n
elif isinstance(node, ast.Str):
return node.s
elif isinstance(node, ast.NameConstant):
return node.value
elif isinstance(node, ast.Call):
funcname = node.func.id
if funcname == "int64":
return core_language.int64(eval_constant(node.args[0]))
elif funcname == "Fraction":
numerator = eval_constant(node.args[0])
denominator = eval_constant(node.args[1])
return Fraction(numerator, denominator)
else:
raise NotConstant
else:
raise NotConstant
_replaceable_funcs = {
"bool", "int", "float", "round",
"int64", "round64", "Fraction",
"seconds_to_mu", "mu_to_seconds"
}
def _is_ref_transparent(dependencies, expr):
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
return True
elif isinstance(expr, ast.Name):
dependencies.add(expr.id)
return True
elif isinstance(expr, ast.UnaryOp):
return _is_ref_transparent(dependencies, expr.operand)
elif isinstance(expr, ast.BinOp):
return (_is_ref_transparent(dependencies, expr.left)
and _is_ref_transparent(dependencies, expr.right))
elif isinstance(expr, ast.BoolOp):
return all(_is_ref_transparent(dependencies, v) for v in expr.values)
elif isinstance(expr, ast.Call):
return (expr.func.id in _replaceable_funcs and
all(_is_ref_transparent(dependencies, arg)
for arg in expr.args))
else:
return False
def is_ref_transparent(expr):
dependencies = set()
if _is_ref_transparent(dependencies, expr):
return True, dependencies
else:
return False, None
class _NodeCounter(ast.NodeVisitor):
def __init__(self):
self.count = 0
def generic_visit(self, node):
self.count += 1
ast.NodeVisitor.generic_visit(self, node)
def count_all_nodes(node):
nc = _NodeCounter()
nc.visit(node)
return nc.count

View File

@ -1,600 +0,0 @@
import sys
import ast
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
def _interleave(inter, f, seq):
"""Call f on each item in seq, calling inter() in between.
"""
seq = iter(seq)
try:
f(next(seq))
except StopIteration:
pass
else:
for x in seq:
inter()
f(x)
class _Unparser:
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
is disregarded. """
def __init__(self, tree):
"""Print the source for tree to the "result" string."""
self.result = ""
self._indent = 0
self.dispatch(tree)
self.result += "\n"
def fill(self, text=""):
"Indent a piece of text, according to the current indentation level"
self.result += "\n"+" "*self._indent + text
def write(self, text):
"Append a piece of text to the current line."
self.result += text
def enter(self):
"Print ':', and increase the indentation."
self.write(":")
self._indent += 1
def leave(self):
"Decrease the indentation level."
self._indent -= 1
def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list):
for t in tree:
self.dispatch(t)
return
meth = getattr(self, "_"+tree.__class__.__name__)
meth(tree)
# Unparsing methods
#
# There should be one method per concrete grammar type
# Constructors should be grouped by sum type. Ideally,
# this would follow the order in the grammar, but
# currently doesn't.
def _Module(self, tree):
for stmt in tree.body:
self.dispatch(stmt)
# stmt
def _Expr(self, tree):
self.fill()
self.dispatch(tree.value)
def _Import(self, t):
self.fill("import ")
_interleave(lambda: self.write(", "), self.dispatch, t.names)
def _ImportFrom(self, t):
self.fill("from ")
self.write("." * t.level)
if t.module:
self.write(t.module)
self.write(" import ")
_interleave(lambda: self.write(", "), self.dispatch, t.names)
def _Assign(self, t):
self.fill()
for target in t.targets:
self.dispatch(target)
self.write(" = ")
self.dispatch(t.value)
def _AugAssign(self, t):
self.fill()
self.dispatch(t.target)
self.write(" "+self.binop[t.op.__class__.__name__]+"= ")
self.dispatch(t.value)
def _Return(self, t):
self.fill("return")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _Pass(self, t):
self.fill("pass")
def _Break(self, t):
self.fill("break")
def _Continue(self, t):
self.fill("continue")
def _Delete(self, t):
self.fill("del ")
_interleave(lambda: self.write(", "), self.dispatch, t.targets)
def _Assert(self, t):
self.fill("assert ")
self.dispatch(t.test)
if t.msg:
self.write(", ")
self.dispatch(t.msg)
def _Global(self, t):
self.fill("global ")
_interleave(lambda: self.write(", "), self.write, t.names)
def _Nonlocal(self, t):
self.fill("nonlocal ")
_interleave(lambda: self.write(", "), self.write, t.names)
def _Yield(self, t):
self.write("(")
self.write("yield")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _YieldFrom(self, t):
self.write("(")
self.write("yield from")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _Raise(self, t):
self.fill("raise")
if not t.exc:
assert not t.cause
return
self.write(" ")
self.dispatch(t.exc)
if t.cause:
self.write(" from ")
self.dispatch(t.cause)
def _Try(self, t):
self.fill("try")
self.enter()
self.dispatch(t.body)
self.leave()
for ex in t.handlers:
self.dispatch(ex)
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
if t.finalbody:
self.fill("finally")
self.enter()
self.dispatch(t.finalbody)
self.leave()
def _ExceptHandler(self, t):
self.fill("except")
if t.type:
self.write(" ")
self.dispatch(t.type)
if t.name:
self.write(" as ")
self.write(t.name)
self.enter()
self.dispatch(t.body)
self.leave()
def _ClassDef(self, t):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
self.fill("class "+t.name)
self.write("(")
comma = False
for e in t.bases:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
for e in t.keywords:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
if t.starargs:
if comma:
self.write(", ")
else:
comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
if comma:
self.write(", ")
else:
comma = True
self.write("**")
self.dispatch(t.kwargs)
self.write(")")
self.enter()
self.dispatch(t.body)
self.leave()
def _FunctionDef(self, t):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
self.fill("def "+t.name + "(")
self.dispatch(t.args)
self.write(")")
if t.returns:
self.write(" -> ")
self.dispatch(t.returns)
self.enter()
self.dispatch(t.body)
self.leave()
def _For(self, t):
self.fill("for ")
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _If(self, t):
self.fill("if ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# collapse nested ifs into equivalent elifs.
while (t.orelse and len(t.orelse) == 1 and
isinstance(t.orelse[0], ast.If)):
t = t.orelse[0]
self.fill("elif ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# final else
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _While(self, t):
self.fill("while ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _With(self, t):
self.fill("with ")
_interleave(lambda: self.write(", "), self.dispatch, t.items)
self.enter()
self.dispatch(t.body)
self.leave()
# expr
def _Bytes(self, t):
self.write(repr(t.s))
def _Str(self, tree):
self.write(repr(tree.s))
def _Name(self, t):
self.write(t.id)
def _NameConstant(self, t):
self.write(repr(t.value))
def _Num(self, t):
# Substitute overflowing decimal literal for AST infinities.
self.write(repr(t.n).replace("inf", INFSTR))
def _List(self, t):
self.write("[")
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("]")
def _ListComp(self, t):
self.write("[")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("]")
def _GeneratorExp(self, t):
self.write("(")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write(")")
def _SetComp(self, t):
self.write("{")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _DictComp(self, t):
self.write("{")
self.dispatch(t.key)
self.write(": ")
self.dispatch(t.value)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _comprehension(self, t):
self.write(" for ")
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
for if_clause in t.ifs:
self.write(" if ")
self.dispatch(if_clause)
def _IfExp(self, t):
self.write("(")
self.dispatch(t.body)
self.write(" if ")
self.dispatch(t.test)
self.write(" else ")
self.dispatch(t.orelse)
self.write(")")
def _Set(self, t):
assert(t.elts) # should be at least one element
self.write("{")
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("}")
def _Dict(self, t):
self.write("{")
def write_pair(pair):
(k, v) = pair
self.dispatch(k)
self.write(": ")
self.dispatch(v)
_interleave(lambda: self.write(", "), write_pair,
zip(t.keys, t.values))
self.write("}")
def _Tuple(self, t):
self.write("(")
if len(t.elts) == 1:
(elt,) = t.elts
self.dispatch(elt)
self.write(",")
else:
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
def _UnaryOp(self, t):
self.write("(")
self.write(self.unop[t.op.__class__.__name__])
self.write(" ")
self.dispatch(t.operand)
self.write(")")
binop = {"Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%",
"LShift": "<<", "RShift": ">>",
"BitOr": "|", "BitXor": "^", "BitAnd": "&",
"FloorDiv": "//", "Pow": "**"}
def _BinOp(self, t):
self.write("(")
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right)
self.write(")")
cmpops = {"Eq": "==", "NotEq": "!=",
"Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=",
"Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"}
def _Compare(self, t):
self.write("(")
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e)
self.write(")")
boolops = {ast.And: "and", ast.Or: "or"}
def _BoolOp(self, t):
self.write("(")
s = " %s " % self.boolops[t.op.__class__]
_interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")
def _Attribute(self, t):
self.dispatch(t.value)
# Special case: 3.__abs__() is a syntax error, so if t.value
# is an integer literal then we need to either parenthesize
# it or add an extra space to get 3 .__abs__().
if isinstance(t.value, ast.Num) and isinstance(t.value.n, int):
self.write(" ")
self.write(".")
self.write(t.attr)
def _Call(self, t):
self.dispatch(t.func)
self.write("(")
comma = False
for e in t.args:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
for e in t.keywords:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
if t.starargs:
if comma:
self.write(", ")
else:
comma = True
self.write("*")
self.dispatch(t.starargs)
if t.kwargs:
if comma:
self.write(", ")
else:
comma = True
self.write("**")
self.dispatch(t.kwargs)
self.write(")")
def _Subscript(self, t):
self.dispatch(t.value)
self.write("[")
self.dispatch(t.slice)
self.write("]")
def _Starred(self, t):
self.write("*")
self.dispatch(t.value)
# slice
def _Ellipsis(self, t):
self.write("...")
def _Index(self, t):
self.dispatch(t.value)
def _Slice(self, t):
if t.lower:
self.dispatch(t.lower)
self.write(":")
if t.upper:
self.dispatch(t.upper)
if t.step:
self.write(":")
self.dispatch(t.step)
def _ExtSlice(self, t):
_interleave(lambda: self.write(', '), self.dispatch, t.dims)
# argument
def _arg(self, t):
self.write(t.arg)
if t.annotation:
self.write(": ")
self.dispatch(t.annotation)
# others
def _arguments(self, t):
first = True
# normal arguments
defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults
for a, d in zip(t.args, defaults):
if first:
first = False
else:
self.write(", ")
self.dispatch(a)
if d:
self.write("=")
self.dispatch(d)
# varargs, or bare '*' if no varargs but keyword-only arguments present
if t.vararg or t.kwonlyargs:
if first:
first = False
else:
self.write(", ")
self.write("*")
if t.vararg:
self.write(t.vararg.arg)
if t.vararg.annotation:
self.write(": ")
self.dispatch(t.vararg.annotation)
# keyword-only arguments
if t.kwonlyargs:
for a, d in zip(t.kwonlyargs, t.kw_defaults):
if first:
first = False
else:
self.write(", ")
self.dispatch(a),
if d:
self.write("=")
self.dispatch(d)
# kwargs
if t.kwarg:
if first:
first = False
else:
self.write(", ")
self.write("**"+t.kwarg.arg)
if t.kwarg.annotation:
self.write(": ")
self.dispatch(t.kwarg.annotation)
def _keyword(self, t):
self.write(t.arg)
self.write("=")
self.dispatch(t.value)
def _Lambda(self, t):
self.write("(")
self.write("lambda ")
self.dispatch(t.args)
self.write(": ")
self.dispatch(t.body)
self.write(")")
def _alias(self, t):
self.write(t.name)
if t.asname:
self.write(" as "+t.asname)
def _withitem(self, t):
self.dispatch(t.context_expr)
if t.optional_vars:
self.write(" as ")
self.dispatch(t.optional_vars)
def unparse(tree):
unparser = _Unparser(tree)
return unparser.result