mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-25 19:28:26 +08:00
Remove parts of py2llvm that are implemented in the new compiler.
This commit is contained in:
parent
62e6f8a03d
commit
200330a808
@ -1,6 +0,0 @@
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
def get_runtime_binary(runtime, func_def):
|
||||
module = Module(runtime)
|
||||
module.compile_function(func_def, dict())
|
||||
return module.emit_object()
|
@ -1,539 +0,0 @@
|
||||
from pythonparser import ast
|
||||
|
||||
import llvmlite_or1k.ir as ll
|
||||
|
||||
from artiq.py2llvm import values, base_types, fractions, lists, iterators
|
||||
from artiq.py2llvm.tools import is_terminated
|
||||
|
||||
|
||||
_ast_unops = {
|
||||
ast.Invert: "o_inv",
|
||||
ast.Not: "o_not",
|
||||
ast.UAdd: "o_pos",
|
||||
ast.USub: "o_neg"
|
||||
}
|
||||
|
||||
_ast_binops = {
|
||||
ast.Add: values.operators.add,
|
||||
ast.Sub: values.operators.sub,
|
||||
ast.Mult: values.operators.mul,
|
||||
ast.Div: values.operators.truediv,
|
||||
ast.FloorDiv: values.operators.floordiv,
|
||||
ast.Mod: values.operators.mod,
|
||||
ast.Pow: values.operators.pow,
|
||||
ast.LShift: values.operators.lshift,
|
||||
ast.RShift: values.operators.rshift,
|
||||
ast.BitOr: values.operators.or_,
|
||||
ast.BitXor: values.operators.xor,
|
||||
ast.BitAnd: values.operators.and_
|
||||
}
|
||||
|
||||
_ast_cmps = {
|
||||
ast.Eq: values.operators.eq,
|
||||
ast.NotEq: values.operators.ne,
|
||||
ast.Lt: values.operators.lt,
|
||||
ast.LtE: values.operators.le,
|
||||
ast.Gt: values.operators.gt,
|
||||
ast.GtE: values.operators.ge
|
||||
}
|
||||
|
||||
|
||||
class Visitor:
|
||||
def __init__(self, runtime, ns, builder=None):
|
||||
self.runtime = runtime
|
||||
self.ns = ns
|
||||
self.builder = builder
|
||||
self._break_stack = []
|
||||
self._continue_stack = []
|
||||
self._active_exception_stack = []
|
||||
self._exception_level_stack = [0]
|
||||
|
||||
# builder can be None for visit_expression
|
||||
def visit_expression(self, node):
|
||||
method = "_visit_expr_" + node.__class__.__name__
|
||||
try:
|
||||
visitor = getattr(self, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError("Unsupported node '{}' in expression"
|
||||
.format(node.__class__.__name__))
|
||||
return visitor(node)
|
||||
|
||||
def _visit_expr_Name(self, node):
|
||||
try:
|
||||
r = self.ns[node.id]
|
||||
except KeyError:
|
||||
raise NameError("Name '{}' is not defined".format(node.id))
|
||||
return r
|
||||
|
||||
def _visit_expr_NameConstant(self, node):
|
||||
v = node.value
|
||||
if v is None:
|
||||
r = base_types.VNone()
|
||||
elif isinstance(v, bool):
|
||||
r = base_types.VBool()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.set_const_value(self.builder, v)
|
||||
return r
|
||||
|
||||
def _visit_expr_Num(self, node):
|
||||
n = node.n
|
||||
if isinstance(n, int):
|
||||
if abs(n) < 2**31:
|
||||
r = base_types.VInt()
|
||||
else:
|
||||
r = base_types.VInt(64)
|
||||
elif isinstance(n, float):
|
||||
r = base_types.VFloat()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.set_const_value(self.builder, n)
|
||||
return r
|
||||
|
||||
def _visit_expr_UnaryOp(self, node):
|
||||
value = self.visit_expression(node.operand)
|
||||
return getattr(value, _ast_unops[type(node.op)])(self.builder)
|
||||
|
||||
def _visit_expr_BinOp(self, node):
|
||||
return _ast_binops[type(node.op)](self.visit_expression(node.left),
|
||||
self.visit_expression(node.right),
|
||||
self.builder)
|
||||
|
||||
def _visit_expr_BoolOp(self, node):
|
||||
if self.builder is not None:
|
||||
initial_block = self.builder.basic_block
|
||||
function = initial_block.function
|
||||
merge_block = function.append_basic_block("b_merge")
|
||||
|
||||
test_blocks = []
|
||||
test_values = []
|
||||
for i, value in enumerate(node.values):
|
||||
if self.builder is not None:
|
||||
test_block = function.append_basic_block("b_{}_test".format(i))
|
||||
test_blocks.append(test_block)
|
||||
self.builder.position_at_end(test_block)
|
||||
test_values.append(self.visit_expression(value))
|
||||
|
||||
result = test_values[0].new()
|
||||
for value in test_values[1:]:
|
||||
result.merge(value)
|
||||
|
||||
if self.builder is not None:
|
||||
self.builder.position_at_end(initial_block)
|
||||
result.alloca(self.builder, "b_result")
|
||||
self.builder.branch(test_blocks[0])
|
||||
|
||||
next_test_blocks = test_blocks[1:]
|
||||
next_test_blocks.append(None)
|
||||
for block, next_block, value in zip(test_blocks,
|
||||
next_test_blocks,
|
||||
test_values):
|
||||
self.builder.position_at_end(block)
|
||||
bval = value.o_bool(self.builder)
|
||||
result.auto_store(self.builder,
|
||||
value.auto_load(self.builder))
|
||||
if next_block is None:
|
||||
self.builder.branch(merge_block)
|
||||
else:
|
||||
if isinstance(node.op, ast.Or):
|
||||
self.builder.cbranch(bval.auto_load(self.builder),
|
||||
merge_block,
|
||||
next_block)
|
||||
elif isinstance(node.op, ast.And):
|
||||
self.builder.cbranch(bval.auto_load(self.builder),
|
||||
next_block,
|
||||
merge_block)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
return result
|
||||
|
||||
def _visit_expr_Compare(self, node):
|
||||
comparisons = []
|
||||
old_comparator = self.visit_expression(node.left)
|
||||
for op, comparator_a in zip(node.ops, node.comparators):
|
||||
comparator = self.visit_expression(comparator_a)
|
||||
comparison = _ast_cmps[type(op)](old_comparator, comparator,
|
||||
self.builder)
|
||||
comparisons.append(comparison)
|
||||
old_comparator = comparator
|
||||
r = comparisons[0]
|
||||
for comparison in comparisons[1:]:
|
||||
r = values.operators.and_(r, comparison)
|
||||
return r
|
||||
|
||||
def _visit_expr_Call(self, node):
|
||||
fn = node.func.id
|
||||
if fn in {"bool", "int", "int64", "round", "round64", "float", "len"}:
|
||||
value = self.visit_expression(node.args[0])
|
||||
return getattr(value, "o_" + fn)(self.builder)
|
||||
elif fn == "Fraction":
|
||||
r = fractions.VFraction()
|
||||
if self.builder is not None:
|
||||
numerator = self.visit_expression(node.args[0])
|
||||
denominator = self.visit_expression(node.args[1])
|
||||
r.set_value_nd(self.builder, numerator, denominator)
|
||||
return r
|
||||
elif fn == "range":
|
||||
return iterators.IRange(
|
||||
self.builder,
|
||||
[self.visit_expression(arg) for arg in node.args])
|
||||
elif fn == "syscall":
|
||||
return self.runtime.build_syscall(
|
||||
node.args[0].s,
|
||||
[self.visit_expression(expr) for expr in node.args[1:]],
|
||||
self.builder)
|
||||
else:
|
||||
raise NameError("Function '{}' is not defined".format(fn))
|
||||
|
||||
def _visit_expr_Attribute(self, node):
|
||||
value = self.visit_expression(node.value)
|
||||
return value.o_getattr(node.attr, self.builder)
|
||||
|
||||
def _visit_expr_List(self, node):
|
||||
elts = [self.visit_expression(elt) for elt in node.elts]
|
||||
if elts:
|
||||
el_type = elts[0].new()
|
||||
for elt in elts[1:]:
|
||||
el_type.merge(elt)
|
||||
else:
|
||||
el_type = base_types.VNone()
|
||||
count = len(elts)
|
||||
r = lists.VList(el_type, count)
|
||||
r.elts = elts
|
||||
return r
|
||||
|
||||
def _visit_expr_ListComp(self, node):
|
||||
if len(node.generators) != 1:
|
||||
raise NotImplementedError
|
||||
generator = node.generators[0]
|
||||
if not isinstance(generator, ast.comprehension):
|
||||
raise NotImplementedError
|
||||
if not isinstance(generator.target, ast.Name):
|
||||
raise NotImplementedError
|
||||
target = generator.target.id
|
||||
if not isinstance(generator.iter, ast.Call):
|
||||
raise NotImplementedError
|
||||
if not isinstance(generator.iter.func, ast.Name):
|
||||
raise NotImplementedError
|
||||
if generator.iter.func.id != "range":
|
||||
raise NotImplementedError
|
||||
if len(generator.iter.args) != 1:
|
||||
raise NotImplementedError
|
||||
if not isinstance(generator.iter.args[0], ast.Num):
|
||||
raise NotImplementedError
|
||||
count = generator.iter.args[0].n
|
||||
|
||||
# Prevent incorrect use of the generator target, if it is defined in
|
||||
# the local function namespace.
|
||||
if target in self.ns:
|
||||
old_target_val = self.ns[target]
|
||||
del self.ns[target]
|
||||
else:
|
||||
old_target_val = None
|
||||
elt = self.visit_expression(node.elt)
|
||||
if old_target_val is not None:
|
||||
self.ns[target] = old_target_val
|
||||
|
||||
el_type = elt.new()
|
||||
r = lists.VList(el_type, count)
|
||||
r.elt = elt
|
||||
return r
|
||||
|
||||
def _visit_expr_Subscript(self, node):
|
||||
value = self.visit_expression(node.value)
|
||||
if isinstance(node.slice, ast.Index):
|
||||
index = self.visit_expression(node.slice.value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return value.o_subscript(index, self.builder)
|
||||
|
||||
def visit_statements(self, stmts):
|
||||
for node in stmts:
|
||||
node_type = node.__class__.__name__
|
||||
method = "_visit_stmt_" + node_type
|
||||
try:
|
||||
visitor = getattr(self, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError("Unsupported node '{}' in statement"
|
||||
.format(node_type))
|
||||
visitor(node)
|
||||
if node_type in ("Return", "Break", "Continue"):
|
||||
break
|
||||
|
||||
def _bb_terminated(self):
|
||||
return is_terminated(self.builder.basic_block)
|
||||
|
||||
def _visit_stmt_Assign(self, node):
|
||||
val = self.visit_expression(node.value)
|
||||
if isinstance(node.value, ast.List):
|
||||
if len(node.targets) > 1:
|
||||
raise NotImplementedError
|
||||
target = self.visit_expression(node.targets[0])
|
||||
target.set_count(self.builder, val.alloc_count)
|
||||
for i, elt in enumerate(val.elts):
|
||||
idx = base_types.VInt()
|
||||
idx.set_const_value(self.builder, i)
|
||||
target.o_subscript(idx, self.builder).set_value(self.builder,
|
||||
elt)
|
||||
elif isinstance(node.value, ast.ListComp):
|
||||
if len(node.targets) > 1:
|
||||
raise NotImplementedError
|
||||
target = self.visit_expression(node.targets[0])
|
||||
target.set_count(self.builder, val.alloc_count)
|
||||
|
||||
i = base_types.VInt()
|
||||
i.alloca(self.builder)
|
||||
i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0))
|
||||
|
||||
function = self.builder.basic_block.function
|
||||
copy_block = function.append_basic_block("ai_copy")
|
||||
end_block = function.append_basic_block("ai_end")
|
||||
self.builder.branch(copy_block)
|
||||
|
||||
self.builder.position_at_end(copy_block)
|
||||
target.o_subscript(i, self.builder).set_value(self.builder,
|
||||
val.elt)
|
||||
i.auto_store(self.builder, self.builder.add(
|
||||
i.auto_load(self.builder),
|
||||
ll.Constant(ll.IntType(32), 1)))
|
||||
cont = self.builder.icmp_signed(
|
||||
"<", i.auto_load(self.builder),
|
||||
ll.Constant(ll.IntType(32), val.alloc_count))
|
||||
self.builder.cbranch(cont, copy_block, end_block)
|
||||
|
||||
self.builder.position_at_end(end_block)
|
||||
else:
|
||||
for target in node.targets:
|
||||
target = self.visit_expression(target)
|
||||
target.set_value(self.builder, val)
|
||||
|
||||
def _visit_stmt_AugAssign(self, node):
|
||||
target = self.visit_expression(node.target)
|
||||
right = self.visit_expression(node.value)
|
||||
val = _ast_binops[type(node.op)](target, right, self.builder)
|
||||
target.set_value(self.builder, val)
|
||||
|
||||
def _visit_stmt_Expr(self, node):
|
||||
self.visit_expression(node.value)
|
||||
|
||||
def _visit_stmt_If(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
then_block = function.append_basic_block("i_then")
|
||||
else_block = function.append_basic_block("i_else")
|
||||
merge_block = function.append_basic_block("i_merge")
|
||||
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(condition.auto_load(self.builder),
|
||||
then_block, else_block)
|
||||
|
||||
self.builder.position_at_end(then_block)
|
||||
self.visit_statements(node.body)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _enter_loop_body(self, break_block, continue_block):
|
||||
self._break_stack.append(break_block)
|
||||
self._continue_stack.append(continue_block)
|
||||
self._exception_level_stack.append(0)
|
||||
|
||||
def _leave_loop_body(self):
|
||||
self._exception_level_stack.pop()
|
||||
self._continue_stack.pop()
|
||||
self._break_stack.pop()
|
||||
|
||||
def _visit_stmt_While(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
|
||||
body_block = function.append_basic_block("w_body")
|
||||
else_block = function.append_basic_block("w_else")
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.auto_load(self.builder), body_block, else_block)
|
||||
|
||||
continue_block = function.append_basic_block("w_continue")
|
||||
merge_block = function.append_basic_block("w_merge")
|
||||
self.builder.position_at_end(body_block)
|
||||
self._enter_loop_body(merge_block, continue_block)
|
||||
self.visit_statements(node.body)
|
||||
self._leave_loop_body()
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(continue_block)
|
||||
|
||||
self.builder.position_at_end(continue_block)
|
||||
condition = self.visit_expression(node.test).o_bool(self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.auto_load(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _visit_stmt_For(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
|
||||
it = self.visit_expression(node.iter)
|
||||
target = self.visit_expression(node.target)
|
||||
itval = it.get_value_ptr()
|
||||
|
||||
body_block = function.append_basic_block("f_body")
|
||||
else_block = function.append_basic_block("f_else")
|
||||
cont = it.o_next(self.builder)
|
||||
self.builder.cbranch(
|
||||
cont.auto_load(self.builder), body_block, else_block)
|
||||
|
||||
continue_block = function.append_basic_block("f_continue")
|
||||
merge_block = function.append_basic_block("f_merge")
|
||||
self.builder.position_at_end(body_block)
|
||||
target.set_value(self.builder, itval)
|
||||
self._enter_loop_body(merge_block, continue_block)
|
||||
self.visit_statements(node.body)
|
||||
self._leave_loop_body()
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(continue_block)
|
||||
|
||||
self.builder.position_at_end(continue_block)
|
||||
cont = it.o_next(self.builder)
|
||||
self.builder.cbranch(
|
||||
cont.auto_load(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _break_loop_body(self, target_block):
|
||||
exception_levels = self._exception_level_stack[-1]
|
||||
if exception_levels:
|
||||
self.runtime.build_pop(self.builder, exception_levels)
|
||||
self.builder.branch(target_block)
|
||||
|
||||
def _visit_stmt_Break(self, node):
|
||||
self._break_loop_body(self._break_stack[-1])
|
||||
|
||||
def _visit_stmt_Continue(self, node):
|
||||
self._break_loop_body(self._continue_stack[-1])
|
||||
|
||||
def _visit_stmt_Return(self, node):
|
||||
if node.value is None:
|
||||
val = base_types.VNone()
|
||||
else:
|
||||
val = self.visit_expression(node.value)
|
||||
exception_levels = sum(self._exception_level_stack)
|
||||
if exception_levels:
|
||||
self.runtime.build_pop(self.builder, exception_levels)
|
||||
if isinstance(val, base_types.VNone):
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.builder.ret(val.auto_load(self.builder))
|
||||
|
||||
def _visit_stmt_Pass(self, node):
|
||||
pass
|
||||
|
||||
def _visit_stmt_Raise(self, node):
|
||||
if self._active_exception_stack:
|
||||
finally_block, propagate, propagate_eid = (
|
||||
self._active_exception_stack[-1])
|
||||
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
|
||||
if node.exc is not None:
|
||||
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
|
||||
self.builder.store(eid, propagate_eid)
|
||||
self.builder.branch(finally_block)
|
||||
else:
|
||||
eid = ll.Constant(ll.IntType(32), node.exc.args[0].n)
|
||||
self.runtime.build_raise(self.builder, eid)
|
||||
|
||||
def _handle_exception(self, function, finally_block,
|
||||
propagate, propagate_eid, handlers):
|
||||
eid = self.runtime.build_getid(self.builder)
|
||||
self._active_exception_stack.append(
|
||||
(finally_block, propagate, propagate_eid))
|
||||
self.builder.store(ll.Constant(ll.IntType(1), 1), propagate)
|
||||
self.builder.store(eid, propagate_eid)
|
||||
|
||||
for handler in handlers:
|
||||
handled_exc_block = function.append_basic_block("try_exc_h")
|
||||
cont_exc_block = function.append_basic_block("try_exc_c")
|
||||
if handler.type is None:
|
||||
self.builder.branch(handled_exc_block)
|
||||
else:
|
||||
if isinstance(handler.type, ast.Tuple):
|
||||
match = self.builder.icmp_signed(
|
||||
"==", eid,
|
||||
ll.Constant(ll.IntType(32),
|
||||
handler.type.elts[0].args[0].n))
|
||||
for elt in handler.type.elts[1:]:
|
||||
match = self.builder.or_(
|
||||
match,
|
||||
self.builder.icmp_signed(
|
||||
"==", eid,
|
||||
ll.Constant(ll.IntType(32), elt.args[0].n)))
|
||||
else:
|
||||
match = self.builder.icmp_signed(
|
||||
"==", eid,
|
||||
ll.Constant(ll.IntType(32), handler.type.args[0].n))
|
||||
self.builder.cbranch(match, handled_exc_block, cont_exc_block)
|
||||
self.builder.position_at_end(handled_exc_block)
|
||||
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
|
||||
self.visit_statements(handler.body)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(finally_block)
|
||||
self.builder.position_at_end(cont_exc_block)
|
||||
self.builder.branch(finally_block)
|
||||
|
||||
self._active_exception_stack.pop()
|
||||
|
||||
def _visit_stmt_Try(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
noexc_block = function.append_basic_block("try_noexc")
|
||||
exc_block = function.append_basic_block("try_exc")
|
||||
finally_block = function.append_basic_block("try_finally")
|
||||
|
||||
propagate = self.builder.alloca(ll.IntType(1),
|
||||
name="propagate")
|
||||
self.builder.store(ll.Constant(ll.IntType(1), 0), propagate)
|
||||
propagate_eid = self.builder.alloca(ll.IntType(32),
|
||||
name="propagate_eid")
|
||||
exception_occured = self.runtime.build_catch(self.builder)
|
||||
self.builder.cbranch(exception_occured, exc_block, noexc_block)
|
||||
|
||||
self.builder.position_at_end(noexc_block)
|
||||
self._exception_level_stack[-1] += 1
|
||||
self.visit_statements(node.body)
|
||||
self._exception_level_stack[-1] -= 1
|
||||
if not self._bb_terminated():
|
||||
self.runtime.build_pop(self.builder, 1)
|
||||
self.visit_statements(node.orelse)
|
||||
if not self._bb_terminated():
|
||||
self.builder.branch(finally_block)
|
||||
self.builder.position_at_end(exc_block)
|
||||
self._handle_exception(function, finally_block,
|
||||
propagate, propagate_eid, node.handlers)
|
||||
|
||||
propagate_block = function.append_basic_block("try_propagate")
|
||||
merge_block = function.append_basic_block("try_merge")
|
||||
self.builder.position_at_end(finally_block)
|
||||
self.visit_statements(node.finalbody)
|
||||
if not self._bb_terminated():
|
||||
self.builder.cbranch(
|
||||
self.builder.load(propagate),
|
||||
propagate_block, merge_block)
|
||||
self.builder.position_at_end(propagate_block)
|
||||
self.runtime.build_raise(self.builder, self.builder.load(propagate_eid))
|
||||
self.builder.branch(merge_block)
|
||||
self.builder.position_at_end(merge_block)
|
@ -1,321 +0,0 @@
|
||||
import llvmlite_or1k.ir as ll
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
|
||||
|
||||
class VNone(VGeneric):
|
||||
def get_llvm_type(self):
|
||||
return ll.VoidType()
|
||||
|
||||
def alloca(self, builder, name):
|
||||
pass
|
||||
|
||||
def set_const_value(self, builder, v):
|
||||
assert v is None
|
||||
|
||||
def set_value(self, builder, other):
|
||||
if not isinstance(other, VNone):
|
||||
raise TypeError
|
||||
|
||||
def o_bool(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_const_value(builder, False)
|
||||
return r
|
||||
|
||||
def o_not(self, builder):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.set_const_value(builder, True)
|
||||
return r
|
||||
|
||||
|
||||
class VInt(VGeneric):
|
||||
def __init__(self, nbits=32):
|
||||
VGeneric.__init__(self)
|
||||
self.nbits = nbits
|
||||
|
||||
def get_llvm_type(self):
|
||||
return ll.IntType(self.nbits)
|
||||
|
||||
def __repr__(self):
|
||||
return "<VInt:{}>".format(self.nbits)
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, VInt) and other.nbits == self.nbits
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VInt) and not isinstance(other, VBool):
|
||||
if other.nbits > self.nbits:
|
||||
self.nbits = other.nbits
|
||||
else:
|
||||
raise TypeError("Incompatible types: {} and {}"
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def set_value(self, builder, n):
|
||||
self.auto_store(
|
||||
builder, n.o_intx(self.nbits, builder).auto_load(builder))
|
||||
|
||||
def set_const_value(self, builder, n):
|
||||
self.auto_store(builder, ll.Constant(self.get_llvm_type(), n))
|
||||
|
||||
def o_bool(self, builder, inv=False):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.auto_store(
|
||||
builder, builder.icmp_signed(
|
||||
"==" if inv else "!=",
|
||||
self.auto_load(builder),
|
||||
ll.Constant(self.get_llvm_type(), 0)))
|
||||
return r
|
||||
|
||||
def o_float(self, builder):
|
||||
r = VFloat()
|
||||
if builder is not None:
|
||||
if isinstance(self, VBool):
|
||||
cf = builder.uitofp
|
||||
else:
|
||||
cf = builder.sitofp
|
||||
r.auto_store(builder, cf(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
|
||||
def o_not(self, builder):
|
||||
return self.o_bool(builder, inv=True)
|
||||
|
||||
def o_neg(self, builder):
|
||||
r = VInt(self.nbits)
|
||||
if builder is not None:
|
||||
r.auto_store(
|
||||
builder, builder.mul(
|
||||
self.auto_load(builder),
|
||||
ll.Constant(self.get_llvm_type(), -1)))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
if self.nbits == target_bits:
|
||||
r.auto_store(
|
||||
builder, self.auto_load(builder))
|
||||
if self.nbits > target_bits:
|
||||
r.auto_store(
|
||||
builder, builder.trunc(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
if self.nbits < target_bits:
|
||||
if isinstance(self, VBool):
|
||||
ef = builder.zext
|
||||
else:
|
||||
ef = builder.sext
|
||||
r.auto_store(
|
||||
builder, ef(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
o_roundx = o_intx
|
||||
|
||||
def o_truediv(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
left = self.o_float(builder)
|
||||
right = other.o_float(builder)
|
||||
return left.o_truediv(right, builder)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def _make_vint_binop_method(builder_name, bool_op):
|
||||
def binop_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
if not bool_op and target_bits == 1:
|
||||
target_bits = 32
|
||||
if bool_op and target_bits == 1:
|
||||
r = VBool()
|
||||
else:
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
bf = getattr(builder, builder_name)
|
||||
r.auto_store(
|
||||
builder, bf(left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return binop_method
|
||||
|
||||
for _method_name, _builder_name, _bool_op in (("o_add", "add", False),
|
||||
("o_sub", "sub", False),
|
||||
("o_mul", "mul", False),
|
||||
("o_floordiv", "sdiv", False),
|
||||
("o_mod", "srem", False),
|
||||
("o_and", "and_", True),
|
||||
("o_xor", "xor", True),
|
||||
("o_or", "or_", True)):
|
||||
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name, _bool_op))
|
||||
|
||||
|
||||
def _make_vint_cmp_method(icmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if isinstance(other, VInt):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
target_bits = max(self.nbits, other.nbits)
|
||||
left = self.o_intx(target_bits, builder)
|
||||
right = other.o_intx(target_bits, builder)
|
||||
r.auto_store(
|
||||
builder,
|
||||
builder.icmp_signed(
|
||||
icmp_val, left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
else:
|
||||
return NotImplemented
|
||||
return cmp_method
|
||||
|
||||
for _method_name, _icmp_val in (("o_eq", "=="),
|
||||
("o_ne", "!="),
|
||||
("o_lt", "<"),
|
||||
("o_le", "<="),
|
||||
("o_gt", ">"),
|
||||
("o_ge", ">=")):
|
||||
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
|
||||
|
||||
|
||||
class VBool(VInt):
|
||||
def __init__(self):
|
||||
VInt.__init__(self, 1)
|
||||
|
||||
__repr__ = VGeneric.__repr__
|
||||
same_type = VGeneric.same_type
|
||||
merge = VGeneric.merge
|
||||
|
||||
def set_const_value(self, builder, b):
|
||||
VInt.set_const_value(self, builder, int(b))
|
||||
|
||||
|
||||
class VFloat(VGeneric):
|
||||
def get_llvm_type(self):
|
||||
return ll.DoubleType()
|
||||
|
||||
def set_value(self, builder, v):
|
||||
if not isinstance(v, VFloat):
|
||||
raise TypeError
|
||||
self.auto_store(builder, v.auto_load(builder))
|
||||
|
||||
def set_const_value(self, builder, n):
|
||||
self.auto_store(builder, ll.Constant(self.get_llvm_type(), n))
|
||||
|
||||
def o_float(self, builder):
|
||||
r = VFloat()
|
||||
if builder is not None:
|
||||
r.auto_store(builder, self.auto_load(builder))
|
||||
return r
|
||||
|
||||
def o_bool(self, builder, inv=False):
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
r.auto_store(
|
||||
builder, builder.fcmp_ordered(
|
||||
"==" if inv else "!=",
|
||||
self.auto_load(builder),
|
||||
ll.Constant(self.get_llvm_type(), 0.0)))
|
||||
return r
|
||||
|
||||
def o_not(self, builder):
|
||||
return self.o_bool(builder, True)
|
||||
|
||||
def o_neg(self, builder):
|
||||
r = VFloat()
|
||||
if builder is not None:
|
||||
r.auto_store(
|
||||
builder, builder.fmul(
|
||||
self.auto_load(builder),
|
||||
ll.Constant(self.get_llvm_type(), -1.0)))
|
||||
return r
|
||||
|
||||
def o_intx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
r.auto_store(builder, builder.fptosi(self.auto_load(builder),
|
||||
r.get_llvm_type()))
|
||||
return r
|
||||
|
||||
def o_roundx(self, target_bits, builder):
|
||||
r = VInt(target_bits)
|
||||
if builder is not None:
|
||||
function = builder.basic_block.function
|
||||
neg_block = function.append_basic_block("fr_neg")
|
||||
merge_block = function.append_basic_block("fr_merge")
|
||||
|
||||
half = VFloat()
|
||||
half.alloca(builder, "half")
|
||||
half.set_const_value(builder, 0.5)
|
||||
|
||||
condition = builder.fcmp_ordered(
|
||||
"<",
|
||||
self.auto_load(builder),
|
||||
ll.Constant(self.get_llvm_type(), 0.0))
|
||||
builder.cbranch(condition, neg_block, merge_block)
|
||||
|
||||
builder.position_at_end(neg_block)
|
||||
half.set_const_value(builder, -0.5)
|
||||
builder.branch(merge_block)
|
||||
|
||||
builder.position_at_end(merge_block)
|
||||
s = builder.fadd(self.auto_load(builder), half.auto_load(builder))
|
||||
r.auto_store(builder, builder.fptosi(s, r.get_llvm_type()))
|
||||
return r
|
||||
|
||||
def o_floordiv(self, other, builder):
|
||||
return self.o_truediv(other, builder).o_int64(builder).o_float(builder)
|
||||
|
||||
def _make_vfloat_binop_method(builder_name, reverse):
|
||||
def binop_method(self, other, builder):
|
||||
if not hasattr(other, "o_float"):
|
||||
return NotImplemented
|
||||
r = VFloat()
|
||||
if builder is not None:
|
||||
left = self.o_float(builder)
|
||||
right = other.o_float(builder)
|
||||
if reverse:
|
||||
left, right = right, left
|
||||
bf = getattr(builder, builder_name)
|
||||
r.auto_store(
|
||||
builder, bf(left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
return binop_method
|
||||
|
||||
for _method_name, _builder_name in (("add", "fadd"),
|
||||
("sub", "fsub"),
|
||||
("mul", "fmul"),
|
||||
("truediv", "fdiv")):
|
||||
setattr(VFloat, "o_" + _method_name,
|
||||
_make_vfloat_binop_method(_builder_name, False))
|
||||
setattr(VFloat, "or_" + _method_name,
|
||||
_make_vfloat_binop_method(_builder_name, True))
|
||||
|
||||
|
||||
def _make_vfloat_cmp_method(fcmp_val):
|
||||
def cmp_method(self, other, builder):
|
||||
if not hasattr(other, "o_float"):
|
||||
return NotImplemented
|
||||
r = VBool()
|
||||
if builder is not None:
|
||||
left = self.o_float(builder)
|
||||
right = other.o_float(builder)
|
||||
r.auto_store(
|
||||
builder,
|
||||
builder.fcmp_ordered(
|
||||
fcmp_val, left.auto_load(builder),
|
||||
right.auto_load(builder)))
|
||||
return r
|
||||
return cmp_method
|
||||
|
||||
for _method_name, _fcmp_val in (("o_eq", "=="),
|
||||
("o_ne", "!="),
|
||||
("o_lt", "<"),
|
||||
("o_le", "<="),
|
||||
("o_gt", ">"),
|
||||
("o_ge", ">=")):
|
||||
setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val))
|
@ -1,75 +0,0 @@
|
||||
import pythonparser.algorithm
|
||||
from pythonparser import ast
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.py2llvm.ast_body import Visitor
|
||||
from artiq.py2llvm import base_types
|
||||
|
||||
|
||||
class _TypeScanner(pythonparser.algorithm.Visitor):
|
||||
def __init__(self, env, ns):
|
||||
self.exprv = Visitor(env, ns)
|
||||
|
||||
def _update_target(self, target, val):
|
||||
ns = self.exprv.ns
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in ns:
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
ns[target.id] = deepcopy(val)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
target = target.value
|
||||
levels = 0
|
||||
while isinstance(target, ast.Subscript):
|
||||
target = target.value
|
||||
levels += 1
|
||||
if isinstance(target, ast.Name):
|
||||
target_value = ns[target.id]
|
||||
for i in range(levels):
|
||||
target_value = target_value.o_subscript(None, None)
|
||||
target_value.merge_subscript(val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
self._update_target(target, val)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit_expression(ast.BinOp(
|
||||
op=node.op, left=node.target, right=node.value))
|
||||
self._update_target(node.target, val)
|
||||
|
||||
def visit_For(self, node):
|
||||
it = self.exprv.visit_expression(node.iter)
|
||||
self._update_target(node.target, it.get_value_ptr())
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
val = base_types.VNone()
|
||||
else:
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
ns = self.exprv.ns
|
||||
if "return" in ns:
|
||||
ns["return"].merge(val)
|
||||
else:
|
||||
ns["return"] = deepcopy(val)
|
||||
|
||||
|
||||
def infer_function_types(env, node, param_types):
|
||||
ns = deepcopy(param_types)
|
||||
ts = _TypeScanner(env, ns)
|
||||
ts.visit(node)
|
||||
while True:
|
||||
prev_ns = deepcopy(ns)
|
||||
ts = _TypeScanner(env, ns)
|
||||
ts.visit(node)
|
||||
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
# no more promotions - completed
|
||||
if "return" not in ns:
|
||||
ns["return"] = base_types.VNone()
|
||||
return ns
|
@ -1,51 +0,0 @@
|
||||
from artiq.py2llvm.values import operators
|
||||
from artiq.py2llvm.base_types import VInt
|
||||
|
||||
class IRange:
|
||||
def __init__(self, builder, args):
|
||||
minimum, step = None, None
|
||||
if len(args) == 1:
|
||||
maximum = args[0]
|
||||
elif len(args) == 2:
|
||||
minimum, maximum = args
|
||||
else:
|
||||
minimum, maximum, step = args
|
||||
if minimum is None:
|
||||
minimum = VInt()
|
||||
if builder is not None:
|
||||
minimum.set_const_value(builder, 0)
|
||||
if step is None:
|
||||
step = VInt()
|
||||
if builder is not None:
|
||||
step.set_const_value(builder, 1)
|
||||
|
||||
self._counter = minimum.new()
|
||||
self._counter.merge(maximum)
|
||||
self._counter.merge(step)
|
||||
self._minimum = self._counter.new()
|
||||
self._maximum = self._counter.new()
|
||||
self._step = self._counter.new()
|
||||
|
||||
if builder is not None:
|
||||
self._minimum.alloca(builder, "irange_min")
|
||||
self._maximum.alloca(builder, "irange_max")
|
||||
self._step.alloca(builder, "irange_step")
|
||||
self._counter.alloca(builder, "irange_count")
|
||||
|
||||
self._minimum.set_value(builder, minimum)
|
||||
self._maximum.set_value(builder, maximum)
|
||||
self._step.set_value(builder, step)
|
||||
|
||||
counter_init = operators.sub(self._minimum, self._step, builder)
|
||||
self._counter.set_value(builder, counter_init)
|
||||
|
||||
# must be a pointer value that can be dereferenced anytime
|
||||
# to get the current value of the iterator
|
||||
def get_value_ptr(self):
|
||||
return self._counter
|
||||
|
||||
def o_next(self, builder):
|
||||
self._counter.set_value(
|
||||
builder,
|
||||
operators.add(self._counter, self._step, builder))
|
||||
return operators.lt(self._counter, self._maximum, builder)
|
@ -1,72 +0,0 @@
|
||||
import llvmlite_or1k.ir as ll
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
from artiq.py2llvm.base_types import VInt, VNone
|
||||
|
||||
|
||||
class VList(VGeneric):
|
||||
def __init__(self, el_type, alloc_count):
|
||||
VGeneric.__init__(self)
|
||||
self.el_type = el_type
|
||||
self.alloc_count = alloc_count
|
||||
|
||||
def get_llvm_type(self):
|
||||
count = 0 if self.alloc_count is None else self.alloc_count
|
||||
if isinstance(self.el_type, VNone):
|
||||
return ll.LiteralStructType([ll.IntType(32)])
|
||||
else:
|
||||
return ll.LiteralStructType([
|
||||
ll.IntType(32), ll.ArrayType(self.el_type.get_llvm_type(),
|
||||
count)])
|
||||
|
||||
def __repr__(self):
|
||||
return "<VList:{} x{}>".format(
|
||||
repr(self.el_type),
|
||||
"?" if self.alloc_count is None else self.alloc_count)
|
||||
|
||||
def same_type(self, other):
|
||||
return (isinstance(other, VList)
|
||||
and self.el_type.same_type(other.el_type))
|
||||
|
||||
def merge(self, other):
|
||||
if isinstance(other, VList):
|
||||
if self.alloc_count:
|
||||
if other.alloc_count:
|
||||
self.el_type.merge(other.el_type)
|
||||
if self.alloc_count < other.alloc_count:
|
||||
self.alloc_count = other.alloc_count
|
||||
else:
|
||||
self.el_type = other.el_type.new()
|
||||
self.alloc_count = other.alloc_count
|
||||
else:
|
||||
raise TypeError("Incompatible types: {} and {}"
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def merge_subscript(self, other):
|
||||
self.el_type.merge(other)
|
||||
|
||||
def set_count(self, builder, count):
|
||||
count_ptr = builder.gep(self.llvm_value, [
|
||||
ll.Constant(ll.IntType(32), 0),
|
||||
ll.Constant(ll.IntType(32), 0)])
|
||||
builder.store(ll.Constant(ll.IntType(32), count), count_ptr)
|
||||
|
||||
def o_len(self, builder):
|
||||
r = VInt()
|
||||
if builder is not None:
|
||||
count_ptr = builder.gep(self.llvm_value, [
|
||||
ll.Constant(ll.IntType(32), 0),
|
||||
ll.Constant(ll.IntType(32), 0)])
|
||||
r.auto_store(builder, builder.load(count_ptr))
|
||||
return r
|
||||
|
||||
def o_subscript(self, index, builder):
|
||||
r = self.el_type.new()
|
||||
if builder is not None and not isinstance(r, VNone):
|
||||
index = index.o_int(builder).auto_load(builder)
|
||||
ssa_r = builder.gep(self.llvm_value, [
|
||||
ll.Constant(ll.IntType(32), 0),
|
||||
ll.Constant(ll.IntType(32), 1),
|
||||
index])
|
||||
r.auto_store(builder, ssa_r)
|
||||
return r
|
@ -1,62 +0,0 @@
|
||||
import llvmlite_or1k.ir as ll
|
||||
import llvmlite_or1k.binding as llvm
|
||||
|
||||
from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(self, runtime=None):
|
||||
self.llvm_module = ll.Module("main")
|
||||
self.runtime = runtime
|
||||
|
||||
if self.runtime is not None:
|
||||
self.runtime.init_module(self)
|
||||
fractions.init_module(self)
|
||||
|
||||
def finalize(self):
|
||||
self.llvm_module_ref = llvm.parse_assembly(str(self.llvm_module))
|
||||
pmb = llvm.create_pass_manager_builder()
|
||||
pmb.opt_level = 2
|
||||
pm = llvm.create_module_pass_manager()
|
||||
pmb.populate(pm)
|
||||
pm.run(self.llvm_module_ref)
|
||||
|
||||
def get_ee(self):
|
||||
self.finalize()
|
||||
tm = llvm.Target.from_default_triple().create_target_machine()
|
||||
ee = llvm.create_mcjit_compiler(self.llvm_module_ref, tm)
|
||||
ee.finalize_object()
|
||||
return ee
|
||||
|
||||
def emit_object(self):
|
||||
self.finalize()
|
||||
return self.runtime.emit_object()
|
||||
|
||||
def compile_function(self, func_def, param_types):
|
||||
ns = infer_types.infer_function_types(self.runtime, func_def, param_types)
|
||||
retval = ns["return"]
|
||||
|
||||
function_type = ll.FunctionType(retval.get_llvm_type(),
|
||||
[ns[arg.arg].get_llvm_type() for arg in func_def.args.args])
|
||||
function = ll.Function(self.llvm_module, function_type, func_def.name)
|
||||
bb = function.append_basic_block("entry")
|
||||
builder = ll.IRBuilder()
|
||||
builder.position_at_end(bb)
|
||||
|
||||
for arg_ast, arg_llvm in zip(func_def.args.args, function.args):
|
||||
arg_llvm.name = arg_ast.arg
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
for arg_ast, arg_llvm in zip(func_def.args.args, function.args):
|
||||
ns[arg_ast.arg].auto_store(builder, arg_llvm)
|
||||
|
||||
visitor = ast_body.Visitor(self.runtime, ns, builder)
|
||||
visitor.visit_statements(func_def.body)
|
||||
|
||||
if not tools.is_terminated(builder.basic_block):
|
||||
if isinstance(retval, base_types.VNone):
|
||||
builder.ret_void()
|
||||
else:
|
||||
builder.ret(retval.auto_load(builder))
|
||||
|
||||
return function, retval
|
169
artiq/py2llvm_old/test/py2llvm.py
Normal file
169
artiq/py2llvm_old/test/py2llvm.py
Normal file
@ -0,0 +1,169 @@
|
||||
import unittest
|
||||
from pythonparser import parse, ast
|
||||
import inspect
|
||||
from fractions import Fraction
|
||||
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
|
||||
import struct
|
||||
|
||||
import llvmlite_or1k.binding as llvm
|
||||
|
||||
from artiq.language.core import int64
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import base_types, lists
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
def simplify_encode(a, b):
|
||||
f = Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode(op, a, b, c, d):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - Fraction(c, d)
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + Fraction(c, d)
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * Fraction(c, d)
|
||||
else:
|
||||
f = Fraction(a, b) / Fraction(c, d)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int(op, a, b, x):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * x
|
||||
else:
|
||||
f = Fraction(a, b) / x
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
f = x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
f = x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
f = x * Fraction(a, b)
|
||||
else:
|
||||
f = x / Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_float(op, a, b, x):
|
||||
if op == 0:
|
||||
return Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
return Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
return Fraction(a, b) * x
|
||||
else:
|
||||
return Fraction(a, b) / x
|
||||
|
||||
|
||||
def frac_arith_float_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
return x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
return x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
return x * Fraction(a, b)
|
||||
else:
|
||||
return x / Fraction(a, b)
|
||||
|
||||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def test_frac_simplify(self):
|
||||
simplify_encode_c = CompiledFunction(
|
||||
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
self.assertEqual(
|
||||
simplify_encode_c(a, b), simplify_encode(a, b))
|
||||
|
||||
def _test_frac_arith(self, op):
|
||||
frac_arith_encode_c = CompiledFunction(
|
||||
frac_arith_encode, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"c": base_types.VInt(), "d": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for c in _test_range():
|
||||
for d in _test_range():
|
||||
self.assertEqual(
|
||||
frac_arith_encode_c(op, a, b, c, d),
|
||||
frac_arith_encode(op, a, b, c, d))
|
||||
|
||||
def test_frac_add(self):
|
||||
self._test_frac_arith(0)
|
||||
|
||||
def test_frac_sub(self):
|
||||
self._test_frac_arith(1)
|
||||
|
||||
def test_frac_mul(self):
|
||||
self._test_frac_arith(2)
|
||||
|
||||
def test_frac_div(self):
|
||||
self._test_frac_arith(3)
|
||||
|
||||
def _test_frac_arith_int(self, op, rev):
|
||||
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertEqual(
|
||||
f_c(op, a, b, x),
|
||||
f(op, a, b, x))
|
||||
|
||||
def test_frac_add_int(self):
|
||||
self._test_frac_arith_int(0, False)
|
||||
self._test_frac_arith_int(0, True)
|
||||
|
||||
def test_frac_sub_int(self):
|
||||
self._test_frac_arith_int(1, False)
|
||||
self._test_frac_arith_int(1, True)
|
||||
|
||||
def test_frac_mul_int(self):
|
||||
self._test_frac_arith_int(2, False)
|
||||
self._test_frac_arith_int(2, True)
|
||||
|
||||
def test_frac_div_int(self):
|
||||
self._test_frac_arith_int(3, False)
|
||||
self._test_frac_arith_int(3, True)
|
||||
|
||||
def _test_frac_arith_float(self, op, rev):
|
||||
f = frac_arith_float_rev if rev else frac_arith_float
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VFloat()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertAlmostEqual(
|
||||
f_c(op, a, b, x/2),
|
||||
f(op, a, b, x/2))
|
||||
|
||||
def test_frac_add_float(self):
|
||||
self._test_frac_arith_float(0, False)
|
||||
self._test_frac_arith_float(0, True)
|
||||
|
||||
def test_frac_sub_float(self):
|
||||
self._test_frac_arith_float(1, False)
|
||||
self._test_frac_arith_float(1, True)
|
||||
|
||||
def test_frac_mul_float(self):
|
||||
self._test_frac_arith_float(2, False)
|
||||
self._test_frac_arith_float(2, True)
|
||||
|
||||
def test_frac_div_float(self):
|
||||
self._test_frac_arith_float(3, False)
|
||||
self._test_frac_arith_float(3, True)
|
@ -1,5 +0,0 @@
|
||||
import llvmlite_or1k.ir as ll
|
||||
|
||||
def is_terminated(basic_block):
|
||||
return (basic_block.instructions
|
||||
and isinstance(basic_block.instructions[-1], ll.Terminator))
|
@ -1,94 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from copy import copy
|
||||
|
||||
import llvmlite_or1k.ir as ll
|
||||
|
||||
|
||||
class VGeneric:
|
||||
def __init__(self):
|
||||
self.llvm_value = None
|
||||
|
||||
def new(self):
|
||||
r = copy(self)
|
||||
r.llvm_value = None
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return "<" + self.__class__.__name__ + ">"
|
||||
|
||||
def same_type(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def merge(self, other):
|
||||
if not self.same_type(other):
|
||||
raise TypeError("Incompatible types: {} and {}"
|
||||
.format(repr(self), repr(other)))
|
||||
|
||||
def auto_load(self, builder):
|
||||
if isinstance(self.llvm_value.type, ll.PointerType):
|
||||
return builder.load(self.llvm_value)
|
||||
else:
|
||||
return self.llvm_value
|
||||
|
||||
def auto_store(self, builder, llvm_value):
|
||||
if self.llvm_value is None:
|
||||
self.llvm_value = llvm_value
|
||||
elif isinstance(self.llvm_value.type, ll.PointerType):
|
||||
builder.store(llvm_value, self.llvm_value)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Attempted to set LLVM SSA value multiple times")
|
||||
|
||||
def alloca(self, builder, name=""):
|
||||
if self.llvm_value is not None:
|
||||
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
||||
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||
|
||||
def o_int(self, builder):
|
||||
return self.o_intx(32, builder)
|
||||
|
||||
def o_int64(self, builder):
|
||||
return self.o_intx(64, builder)
|
||||
|
||||
def o_round(self, builder):
|
||||
return self.o_roundx(32, builder)
|
||||
|
||||
def o_round64(self, builder):
|
||||
return self.o_roundx(64, builder)
|
||||
|
||||
|
||||
def _make_binary_operator(op_name):
|
||||
def op(l, r, builder):
|
||||
try:
|
||||
opf = getattr(l, "o_" + op_name)
|
||||
except AttributeError:
|
||||
result = NotImplemented
|
||||
else:
|
||||
result = opf(r, builder)
|
||||
if result is NotImplemented:
|
||||
try:
|
||||
ropf = getattr(r, "or_" + op_name)
|
||||
except AttributeError:
|
||||
result = NotImplemented
|
||||
else:
|
||||
result = ropf(l, builder)
|
||||
if result is NotImplemented:
|
||||
raise TypeError(
|
||||
"Unsupported operand types for {}: {} and {}"
|
||||
.format(op_name, type(l).__name__, type(r).__name__))
|
||||
return result
|
||||
return op
|
||||
|
||||
|
||||
def _make_operators():
|
||||
d = dict()
|
||||
for op_name in ("add", "sub", "mul",
|
||||
"truediv", "floordiv", "mod",
|
||||
"pow", "lshift", "rshift", "xor",
|
||||
"eq", "ne", "lt", "le", "gt", "ge"):
|
||||
d[op_name] = _make_binary_operator(op_name)
|
||||
d["and_"] = _make_binary_operator("and")
|
||||
d["or_"] = _make_binary_operator("or")
|
||||
return SimpleNamespace(**d)
|
||||
|
||||
operators = _make_operators()
|
@ -1,372 +0,0 @@
|
||||
import unittest
|
||||
from pythonparser import parse, ast
|
||||
import inspect
|
||||
from fractions import Fraction
|
||||
from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double
|
||||
import struct
|
||||
|
||||
import llvmlite_or1k.binding as llvm
|
||||
|
||||
from artiq.language.core import int64
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import base_types, lists
|
||||
from artiq.py2llvm.module import Module
|
||||
|
||||
|
||||
llvm.initialize()
|
||||
llvm.initialize_native_target()
|
||||
llvm.initialize_native_asmprinter()
|
||||
if struct.calcsize("P") < 8:
|
||||
from ctypes import _dlopen, RTLD_GLOBAL
|
||||
_dlopen("libgcc_s.so", RTLD_GLOBAL)
|
||||
|
||||
|
||||
def _base_types(choice):
|
||||
a = 2 # promoted later to int64
|
||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||
d = 4 and 5 # stays int32
|
||||
x = int64(7)
|
||||
a += x # promotes a to int64
|
||||
foo = True | True or False
|
||||
bar = None
|
||||
myf = 4.5
|
||||
myf2 = myf + x
|
||||
|
||||
if choice and foo and not bar:
|
||||
return d
|
||||
elif myf2:
|
||||
return x + c
|
||||
else:
|
||||
return int64(8)
|
||||
|
||||
|
||||
def _build_function_types(f):
|
||||
return infer_function_types(
|
||||
None, parse(inspect.getsource(f)),
|
||||
dict())
|
||||
|
||||
|
||||
class FunctionBaseTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = _build_function_types(_base_types)
|
||||
|
||||
def test_simple_types(self):
|
||||
self.assertIsInstance(self.ns["foo"], base_types.VBool)
|
||||
self.assertIsInstance(self.ns["bar"], base_types.VNone)
|
||||
self.assertIsInstance(self.ns["d"], base_types.VInt)
|
||||
self.assertEqual(self.ns["d"].nbits, 32)
|
||||
self.assertIsInstance(self.ns["x"], base_types.VInt)
|
||||
self.assertEqual(self.ns["x"].nbits, 64)
|
||||
self.assertIsInstance(self.ns["myf"], base_types.VFloat)
|
||||
self.assertIsInstance(self.ns["myf2"], base_types.VFloat)
|
||||
|
||||
def test_promotion(self):
|
||||
for v in "abc":
|
||||
self.assertIsInstance(self.ns[v], base_types.VInt)
|
||||
self.assertEqual(self.ns[v].nbits, 64)
|
||||
|
||||
def test_return(self):
|
||||
self.assertIsInstance(self.ns["return"], base_types.VInt)
|
||||
self.assertEqual(self.ns["return"].nbits, 64)
|
||||
|
||||
|
||||
def test_list_types():
|
||||
a = [0, 0, 0, 0, 0]
|
||||
for i in range(2):
|
||||
a[i] = int64(8)
|
||||
return a
|
||||
|
||||
|
||||
class FunctionListTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = _build_function_types(test_list_types)
|
||||
|
||||
def test_list_types(self):
|
||||
self.assertIsInstance(self.ns["a"], lists.VList)
|
||||
self.assertIsInstance(self.ns["a"].el_type, base_types.VInt)
|
||||
self.assertEqual(self.ns["a"].el_type.nbits, 64)
|
||||
self.assertEqual(self.ns["a"].alloc_count, 5)
|
||||
self.assertIsInstance(self.ns["i"], base_types.VInt)
|
||||
self.assertEqual(self.ns["i"].nbits, 32)
|
||||
|
||||
|
||||
def _value_to_ctype(v):
|
||||
if isinstance(v, base_types.VBool):
|
||||
return c_int
|
||||
elif isinstance(v, base_types.VInt):
|
||||
if v.nbits == 32:
|
||||
return c_int32
|
||||
elif v.nbits == 64:
|
||||
return c_int64
|
||||
else:
|
||||
raise NotImplementedError(str(v))
|
||||
elif isinstance(v, base_types.VFloat):
|
||||
return c_double
|
||||
else:
|
||||
raise NotImplementedError(str(v))
|
||||
|
||||
|
||||
class CompiledFunction:
|
||||
def __init__(self, function, param_types):
|
||||
module = Module()
|
||||
|
||||
func_def = parse(inspect.getsource(function)).body[0]
|
||||
function, retval = module.compile_function(func_def, param_types)
|
||||
argvals = [param_types[arg.arg] for arg in func_def.args.args]
|
||||
|
||||
ee = module.get_ee()
|
||||
cfptr = ee.get_pointer_to_global(
|
||||
module.llvm_module_ref.get_function(function.name))
|
||||
retval_ctype = _value_to_ctype(retval)
|
||||
argval_ctypes = [_value_to_ctype(argval) for argval in argvals]
|
||||
self.cfunc = CFUNCTYPE(retval_ctype, *argval_ctypes)(cfptr)
|
||||
|
||||
# HACK: prevent garbage collection of self.cfunc internals
|
||||
self.ee = ee
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.cfunc(*args)
|
||||
|
||||
|
||||
def arith(op, a, b):
|
||||
if op == 0:
|
||||
return a + b
|
||||
elif op == 1:
|
||||
return a - b
|
||||
elif op == 2:
|
||||
return a * b
|
||||
else:
|
||||
return a / b
|
||||
|
||||
|
||||
def is_prime(x):
|
||||
d = 2
|
||||
while d*d <= x:
|
||||
if not x % d:
|
||||
return False
|
||||
d += 1
|
||||
return True
|
||||
|
||||
|
||||
def simplify_encode(a, b):
|
||||
f = Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode(op, a, b, c, d):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - Fraction(c, d)
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + Fraction(c, d)
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * Fraction(c, d)
|
||||
else:
|
||||
f = Fraction(a, b) / Fraction(c, d)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int(op, a, b, x):
|
||||
if op == 0:
|
||||
f = Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
f = Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
f = Fraction(a, b) * x
|
||||
else:
|
||||
f = Fraction(a, b) / x
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_encode_int_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
f = x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
f = x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
f = x * Fraction(a, b)
|
||||
else:
|
||||
f = x / Fraction(a, b)
|
||||
return f.numerator*1000 + f.denominator
|
||||
|
||||
|
||||
def frac_arith_float(op, a, b, x):
|
||||
if op == 0:
|
||||
return Fraction(a, b) - x
|
||||
elif op == 1:
|
||||
return Fraction(a, b) + x
|
||||
elif op == 2:
|
||||
return Fraction(a, b) * x
|
||||
else:
|
||||
return Fraction(a, b) / x
|
||||
|
||||
|
||||
def frac_arith_float_rev(op, a, b, x):
|
||||
if op == 0:
|
||||
return x - Fraction(a, b)
|
||||
elif op == 1:
|
||||
return x + Fraction(a, b)
|
||||
elif op == 2:
|
||||
return x * Fraction(a, b)
|
||||
else:
|
||||
return x / Fraction(a, b)
|
||||
|
||||
|
||||
def list_test():
|
||||
x = 80
|
||||
a = [3 for x in range(7)]
|
||||
b = [1, 2, 4, 5, 4, 0, 5]
|
||||
a[3] = x
|
||||
a[0] += 6
|
||||
a[1] = b[1] + b[2]
|
||||
|
||||
acc = 0
|
||||
for i in range(7):
|
||||
if i and a[i]:
|
||||
acc += 1
|
||||
acc += a[i]
|
||||
return acc
|
||||
|
||||
|
||||
def corner_cases():
|
||||
two = True + True - (not True)
|
||||
three = two + True//True - False*True
|
||||
two_float = three - True/True
|
||||
one_float = two_float - (1.0 == bool(0.1))
|
||||
zero = int(one_float) + round(-0.6)
|
||||
eleven_float = zero + 5.5//0.5
|
||||
ten_float = eleven_float + round(Fraction(2, -3))
|
||||
return ten_float
|
||||
|
||||
|
||||
def _test_range():
|
||||
for i in range(5, 10):
|
||||
yield i
|
||||
yield -i
|
||||
|
||||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def _test_float_arith(self, op):
|
||||
arith_c = CompiledFunction(arith, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VFloat(), "b": base_types.VFloat()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2))
|
||||
|
||||
def test_float_add(self):
|
||||
self._test_float_arith(0)
|
||||
|
||||
def test_float_sub(self):
|
||||
self._test_float_arith(1)
|
||||
|
||||
def test_float_mul(self):
|
||||
self._test_float_arith(2)
|
||||
|
||||
def test_float_div(self):
|
||||
self._test_float_arith(3)
|
||||
|
||||
def test_is_prime(self):
|
||||
is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()})
|
||||
for i in range(200):
|
||||
self.assertEqual(is_prime_c(i), is_prime(i))
|
||||
|
||||
def test_frac_simplify(self):
|
||||
simplify_encode_c = CompiledFunction(
|
||||
simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
self.assertEqual(
|
||||
simplify_encode_c(a, b), simplify_encode(a, b))
|
||||
|
||||
def _test_frac_arith(self, op):
|
||||
frac_arith_encode_c = CompiledFunction(
|
||||
frac_arith_encode, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"c": base_types.VInt(), "d": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for c in _test_range():
|
||||
for d in _test_range():
|
||||
self.assertEqual(
|
||||
frac_arith_encode_c(op, a, b, c, d),
|
||||
frac_arith_encode(op, a, b, c, d))
|
||||
|
||||
def test_frac_add(self):
|
||||
self._test_frac_arith(0)
|
||||
|
||||
def test_frac_sub(self):
|
||||
self._test_frac_arith(1)
|
||||
|
||||
def test_frac_mul(self):
|
||||
self._test_frac_arith(2)
|
||||
|
||||
def test_frac_div(self):
|
||||
self._test_frac_arith(3)
|
||||
|
||||
def _test_frac_arith_int(self, op, rev):
|
||||
f = frac_arith_encode_int_rev if rev else frac_arith_encode_int
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VInt()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertEqual(
|
||||
f_c(op, a, b, x),
|
||||
f(op, a, b, x))
|
||||
|
||||
def test_frac_add_int(self):
|
||||
self._test_frac_arith_int(0, False)
|
||||
self._test_frac_arith_int(0, True)
|
||||
|
||||
def test_frac_sub_int(self):
|
||||
self._test_frac_arith_int(1, False)
|
||||
self._test_frac_arith_int(1, True)
|
||||
|
||||
def test_frac_mul_int(self):
|
||||
self._test_frac_arith_int(2, False)
|
||||
self._test_frac_arith_int(2, True)
|
||||
|
||||
def test_frac_div_int(self):
|
||||
self._test_frac_arith_int(3, False)
|
||||
self._test_frac_arith_int(3, True)
|
||||
|
||||
def _test_frac_arith_float(self, op, rev):
|
||||
f = frac_arith_float_rev if rev else frac_arith_float
|
||||
f_c = CompiledFunction(f, {
|
||||
"op": base_types.VInt(),
|
||||
"a": base_types.VInt(), "b": base_types.VInt(),
|
||||
"x": base_types.VFloat()})
|
||||
for a in _test_range():
|
||||
for b in _test_range():
|
||||
for x in _test_range():
|
||||
self.assertAlmostEqual(
|
||||
f_c(op, a, b, x/2),
|
||||
f(op, a, b, x/2))
|
||||
|
||||
def test_frac_add_float(self):
|
||||
self._test_frac_arith_float(0, False)
|
||||
self._test_frac_arith_float(0, True)
|
||||
|
||||
def test_frac_sub_float(self):
|
||||
self._test_frac_arith_float(1, False)
|
||||
self._test_frac_arith_float(1, True)
|
||||
|
||||
def test_frac_mul_float(self):
|
||||
self._test_frac_arith_float(2, False)
|
||||
self._test_frac_arith_float(2, True)
|
||||
|
||||
def test_frac_div_float(self):
|
||||
self._test_frac_arith_float(3, False)
|
||||
self._test_frac_arith_float(3, True)
|
||||
|
||||
def test_list(self):
|
||||
list_test_c = CompiledFunction(list_test, dict())
|
||||
self.assertEqual(list_test_c(), list_test())
|
||||
|
||||
def test_corner_cases(self):
|
||||
corner_cases_c = CompiledFunction(corner_cases, dict())
|
||||
self.assertEqual(corner_cases_c(), corner_cases())
|
@ -1,44 +0,0 @@
|
||||
import unittest
|
||||
import ast
|
||||
|
||||
from artiq import ns
|
||||
from artiq.coredevice import comm_dummy, core
|
||||
from artiq.transforms.unparse import unparse
|
||||
|
||||
|
||||
optimize_in = """
|
||||
|
||||
def run():
|
||||
dds_sysclk = Fraction(1000000000, 1)
|
||||
n = seconds_to_mu((1.2345 * Fraction(1, 1000000000)))
|
||||
with sequential:
|
||||
frequency = 345 * Fraction(1000000, 1)
|
||||
frequency_to_ftw_return = int((((2 ** 32) * frequency) / dds_sysclk))
|
||||
ftw = frequency_to_ftw_return
|
||||
with sequential:
|
||||
ftw2 = ftw
|
||||
ftw_to_frequency_return = ((ftw2 * dds_sysclk) / (2 ** 32))
|
||||
f = ftw_to_frequency_return
|
||||
phi = ((1000 * mu_to_seconds(n)) * f)
|
||||
do_something(int(phi))
|
||||
"""
|
||||
|
||||
optimize_out = """
|
||||
|
||||
def run():
|
||||
now = syscall('now_init')
|
||||
try:
|
||||
do_something(344)
|
||||
finally:
|
||||
syscall('now_save', now)
|
||||
"""
|
||||
|
||||
|
||||
class OptimizeCase(unittest.TestCase):
|
||||
def test_optimize(self):
|
||||
dmgr = dict()
|
||||
dmgr["comm"] = comm_dummy.Comm(dmgr)
|
||||
coredev = core.Core(dmgr, ref_period=1*ns)
|
||||
func_def = ast.parse(optimize_in).body[0]
|
||||
coredev.transform_stack(func_def, dict(), dict())
|
||||
self.assertEqual(unparse(func_def), optimize_out)
|
@ -1,156 +0,0 @@
|
||||
import ast
|
||||
import operator
|
||||
from fractions import Fraction
|
||||
|
||||
from artiq.transforms.tools import *
|
||||
from artiq.language.core import int64, round64
|
||||
|
||||
|
||||
_ast_unops = {
|
||||
ast.Invert: operator.inv,
|
||||
ast.Not: operator.not_,
|
||||
ast.UAdd: operator.pos,
|
||||
ast.USub: operator.neg
|
||||
}
|
||||
|
||||
|
||||
_ast_binops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.LShift: operator.lshift,
|
||||
ast.RShift: operator.rshift,
|
||||
ast.BitOr: operator.or_,
|
||||
ast.BitXor: operator.xor,
|
||||
ast.BitAnd: operator.and_
|
||||
}
|
||||
|
||||
_ast_cmpops = {
|
||||
ast.Eq: operator.eq,
|
||||
ast.NotEq: operator.ne,
|
||||
ast.Lt: operator.lt,
|
||||
ast.LtE: operator.le,
|
||||
ast.Gt: operator.gt,
|
||||
ast.GtE: operator.ge
|
||||
}
|
||||
|
||||
_ast_boolops = {
|
||||
ast.Or: lambda x, y: x or y,
|
||||
ast.And: lambda x, y: x and y
|
||||
}
|
||||
|
||||
|
||||
class _ConstantFolder(ast.NodeTransformer):
|
||||
def visit_UnaryOp(self, node):
|
||||
self.generic_visit(node)
|
||||
try:
|
||||
operand = eval_constant(node.operand)
|
||||
except NotConstant:
|
||||
return node
|
||||
try:
|
||||
op = _ast_unops[type(node.op)]
|
||||
except KeyError:
|
||||
return node
|
||||
try:
|
||||
result = value_to_ast(op(operand))
|
||||
except:
|
||||
return node
|
||||
return ast.copy_location(result, node)
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
self.generic_visit(node)
|
||||
try:
|
||||
left, right = eval_constant(node.left), eval_constant(node.right)
|
||||
except NotConstant:
|
||||
return node
|
||||
try:
|
||||
op = _ast_binops[type(node.op)]
|
||||
except KeyError:
|
||||
return node
|
||||
try:
|
||||
result = value_to_ast(op(left, right))
|
||||
except:
|
||||
return node
|
||||
return ast.copy_location(result, node)
|
||||
|
||||
def visit_Compare(self, node):
|
||||
self.generic_visit(node)
|
||||
try:
|
||||
operands = [eval_constant(node.left)]
|
||||
except NotConstant:
|
||||
operands = [node.left]
|
||||
ops = []
|
||||
for op, right_ast in zip(node.ops, node.comparators):
|
||||
try:
|
||||
right = eval_constant(right_ast)
|
||||
except NotConstant:
|
||||
right = right_ast
|
||||
if (not isinstance(operands[-1], ast.AST)
|
||||
and not isinstance(right, ast.AST)):
|
||||
left = operands.pop()
|
||||
operands.append(_ast_cmpops[type(op)](left, right))
|
||||
else:
|
||||
ops.append(op)
|
||||
operands.append(right_ast)
|
||||
operands = [operand if isinstance(operand, ast.AST)
|
||||
else ast.copy_location(value_to_ast(operand), node)
|
||||
for operand in operands]
|
||||
if len(operands) == 1:
|
||||
return operands[0]
|
||||
else:
|
||||
node.left = operands[0]
|
||||
node.right = operands[1:]
|
||||
node.ops = ops
|
||||
return node
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
self.generic_visit(node)
|
||||
new_values = []
|
||||
for value in node.values:
|
||||
try:
|
||||
value_c = eval_constant(value)
|
||||
except NotConstant:
|
||||
new_values.append(value)
|
||||
else:
|
||||
if new_values and not isinstance(new_values[-1], ast.AST):
|
||||
op = _ast_boolops[type(node.op)]
|
||||
new_values[-1] = op(new_values[-1], value_c)
|
||||
else:
|
||||
new_values.append(value_c)
|
||||
new_values = [v if isinstance(v, ast.AST) else value_to_ast(v)
|
||||
for v in new_values]
|
||||
if len(new_values) > 1:
|
||||
node.values = new_values
|
||||
return node
|
||||
else:
|
||||
return new_values[0]
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.generic_visit(node)
|
||||
fn = node.func.id
|
||||
constant_ops = {
|
||||
"int": int,
|
||||
"int64": int64,
|
||||
"round": round,
|
||||
"round64": round64,
|
||||
"Fraction": Fraction
|
||||
}
|
||||
if fn in constant_ops:
|
||||
args = []
|
||||
for arg in node.args:
|
||||
try:
|
||||
args.append(eval_constant(arg))
|
||||
except NotConstant:
|
||||
return node
|
||||
result = value_to_ast(constant_ops[fn](*args))
|
||||
return ast.copy_location(result, node)
|
||||
else:
|
||||
return node
|
||||
|
||||
|
||||
def fold_constants(node):
|
||||
_ConstantFolder().visit(node)
|
@ -1,59 +0,0 @@
|
||||
import ast
|
||||
|
||||
from artiq.transforms.tools import is_ref_transparent
|
||||
|
||||
|
||||
class _SourceLister(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.sources = set()
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.sources.add(node.id)
|
||||
|
||||
|
||||
class _DeadCodeRemover(ast.NodeTransformer):
|
||||
def __init__(self, kept_targets):
|
||||
self.kept_targets = kept_targets
|
||||
|
||||
def visit_Assign(self, node):
|
||||
new_targets = []
|
||||
for target in node.targets:
|
||||
if (not isinstance(target, ast.Name)
|
||||
or target.id in self.kept_targets):
|
||||
new_targets.append(target)
|
||||
if not new_targets and is_ref_transparent(node.value)[0]:
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
if (isinstance(node.target, ast.Name)
|
||||
and node.target.id not in self.kept_targets
|
||||
and is_ref_transparent(node.value)[0]):
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_If(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.test, ast.NameConstant):
|
||||
if node.test.value:
|
||||
return node.body
|
||||
else:
|
||||
return node.orelse
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_While(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.test, ast.NameConstant) and not node.test.value:
|
||||
return node.orelse
|
||||
else:
|
||||
return node
|
||||
|
||||
|
||||
def remove_dead_code(func_def):
|
||||
sl = _SourceLister()
|
||||
sl.visit(func_def)
|
||||
_DeadCodeRemover(sl.sources).visit(func_def)
|
@ -1,149 +0,0 @@
|
||||
import ast
|
||||
from copy import copy, deepcopy
|
||||
from collections import defaultdict
|
||||
|
||||
from artiq.transforms.tools import is_ref_transparent, count_all_nodes
|
||||
|
||||
|
||||
class _TargetLister(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.targets = set()
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
self.targets.add(node.id)
|
||||
|
||||
|
||||
class _InterAssignRemover(ast.NodeTransformer):
|
||||
def __init__(self):
|
||||
self.replacements = dict()
|
||||
self.modified_names = set()
|
||||
# name -> set of names that depend on it
|
||||
# i.e. when x is modified, dependencies[x] is the set of names that
|
||||
# cannot be replaced anymore
|
||||
self.dependencies = defaultdict(set)
|
||||
|
||||
def invalidate(self, name):
|
||||
try:
|
||||
del self.replacements[name]
|
||||
except KeyError:
|
||||
pass
|
||||
for d in self.dependencies[name]:
|
||||
self.invalidate(d)
|
||||
del self.dependencies[name]
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
try:
|
||||
return deepcopy(self.replacements[node.id])
|
||||
except KeyError:
|
||||
return node
|
||||
else:
|
||||
self.modified_names.add(node.id)
|
||||
self.invalidate(node.id)
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node):
|
||||
node.value = self.visit(node.value)
|
||||
node.targets = [self.visit(target) for target in node.targets]
|
||||
rt, depends_on = is_ref_transparent(node.value)
|
||||
if rt and count_all_nodes(node.value) < 100:
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id not in depends_on:
|
||||
self.replacements[target.id] = node.value
|
||||
for d in depends_on:
|
||||
self.dependencies[d].add(target.id)
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
left = deepcopy(node.target)
|
||||
left.ctx = ast.Load()
|
||||
newnode = ast.copy_location(
|
||||
ast.Assign(
|
||||
targets=[node.target],
|
||||
value=ast.BinOp(left=left, op=node.op, right=node.value)
|
||||
),
|
||||
node
|
||||
)
|
||||
return self.visit_Assign(newnode)
|
||||
|
||||
def modified_names_push(self):
|
||||
prev_modified_names = self.modified_names
|
||||
self.modified_names = set()
|
||||
return prev_modified_names
|
||||
|
||||
def modified_names_pop(self, prev_modified_names):
|
||||
for name in self.modified_names:
|
||||
self.invalidate(name)
|
||||
self.modified_names |= prev_modified_names
|
||||
|
||||
def visit_Try(self, node):
|
||||
prev_modified_names = self.modified_names_push()
|
||||
node.body = [self.visit(stmt) for stmt in node.body]
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
|
||||
prev_modified_names = self.modified_names_push()
|
||||
prev_replacements = self.replacements
|
||||
for handler in node.handlers:
|
||||
self.replacements = copy(prev_replacements)
|
||||
handler.body = [self.visit(stmt) for stmt in handler.body]
|
||||
self.replacements = copy(prev_replacements)
|
||||
node.orelse = [self.visit(stmt) for stmt in node.orelse]
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
|
||||
prev_modified_names = self.modified_names_push()
|
||||
node.finalbody = [self.visit(stmt) for stmt in node.finalbody]
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
return node
|
||||
|
||||
def visit_If(self, node):
|
||||
node.test = self.visit(node.test)
|
||||
|
||||
prev_modified_names = self.modified_names_push()
|
||||
|
||||
prev_replacements = self.replacements
|
||||
self.replacements = copy(prev_replacements)
|
||||
node.body = [self.visit(n) for n in node.body]
|
||||
self.replacements = copy(prev_replacements)
|
||||
node.orelse = [self.visit(n) for n in node.orelse]
|
||||
self.replacements = prev_replacements
|
||||
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
|
||||
return node
|
||||
|
||||
def visit_loop(self, node):
|
||||
prev_modified_names = self.modified_names_push()
|
||||
prev_replacements = self.replacements
|
||||
|
||||
self.replacements = copy(prev_replacements)
|
||||
tl = _TargetLister()
|
||||
for n in node.body:
|
||||
tl.visit(n)
|
||||
for name in tl.targets:
|
||||
self.invalidate(name)
|
||||
node.body = [self.visit(n) for n in node.body]
|
||||
|
||||
self.replacements = copy(prev_replacements)
|
||||
node.orelse = [self.visit(n) for n in node.orelse]
|
||||
|
||||
self.replacements = prev_replacements
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
|
||||
def visit_For(self, node):
|
||||
prev_modified_names = self.modified_names_push()
|
||||
node.target = self.visit(node.target)
|
||||
self.modified_names_pop(prev_modified_names)
|
||||
node.iter = self.visit(node.iter)
|
||||
self.visit_loop(node)
|
||||
return node
|
||||
|
||||
def visit_While(self, node):
|
||||
self.visit_loop(node)
|
||||
node.test = self.visit(node.test)
|
||||
return node
|
||||
|
||||
|
||||
def remove_inter_assigns(func_def):
|
||||
_InterAssignRemover().visit(func_def)
|
@ -1,141 +0,0 @@
|
||||
import ast
|
||||
from fractions import Fraction
|
||||
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
|
||||
|
||||
embeddable_funcs = (
|
||||
core_language.delay_mu, core_language.at_mu, core_language.now_mu,
|
||||
core_language.delay,
|
||||
core_language.seconds_to_mu, core_language.mu_to_seconds,
|
||||
core_language.syscall, core_language.watchdog,
|
||||
range, bool, int, float, round, len,
|
||||
core_language.int64, core_language.round64,
|
||||
Fraction, core_language.EncodedException
|
||||
)
|
||||
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
|
||||
|
||||
|
||||
def is_embeddable(func):
|
||||
for ef in embeddable_funcs:
|
||||
if func is ef:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def eval_ast(expr, symdict=dict()):
|
||||
if not isinstance(expr, ast.Expression):
|
||||
expr = ast.copy_location(ast.Expression(expr), expr)
|
||||
ast.fix_missing_locations(expr)
|
||||
code = compile(expr, "<ast>", "eval")
|
||||
return eval(code, symdict)
|
||||
|
||||
|
||||
class NotASTRepresentable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def value_to_ast(value):
|
||||
if isinstance(value, core_language.int64): # must be before int
|
||||
return ast.Call(
|
||||
func=ast.Name("int64", ast.Load()),
|
||||
args=[ast.Num(int(value))],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
elif isinstance(value, bool) or value is None:
|
||||
# must also be before int
|
||||
# isinstance(True/False, int) == True
|
||||
return ast.NameConstant(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return ast.Num(value)
|
||||
elif isinstance(value, Fraction):
|
||||
return ast.Call(
|
||||
func=ast.Name("Fraction", ast.Load()),
|
||||
args=[ast.Num(value.numerator), ast.Num(value.denominator)],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
elif isinstance(value, str):
|
||||
return ast.Str(value)
|
||||
elif isinstance(value, list):
|
||||
elts = [value_to_ast(elt) for elt in value]
|
||||
return ast.List(elts, ast.Load())
|
||||
else:
|
||||
for kg in core_language.kernel_globals:
|
||||
if value is getattr(core_language, kg):
|
||||
return ast.Name(kg, ast.Load())
|
||||
raise NotASTRepresentable(str(value))
|
||||
|
||||
|
||||
class NotConstant(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def eval_constant(node):
|
||||
if isinstance(node, ast.Num):
|
||||
return node.n
|
||||
elif isinstance(node, ast.Str):
|
||||
return node.s
|
||||
elif isinstance(node, ast.NameConstant):
|
||||
return node.value
|
||||
elif isinstance(node, ast.Call):
|
||||
funcname = node.func.id
|
||||
if funcname == "int64":
|
||||
return core_language.int64(eval_constant(node.args[0]))
|
||||
elif funcname == "Fraction":
|
||||
numerator = eval_constant(node.args[0])
|
||||
denominator = eval_constant(node.args[1])
|
||||
return Fraction(numerator, denominator)
|
||||
else:
|
||||
raise NotConstant
|
||||
else:
|
||||
raise NotConstant
|
||||
|
||||
|
||||
_replaceable_funcs = {
|
||||
"bool", "int", "float", "round",
|
||||
"int64", "round64", "Fraction",
|
||||
"seconds_to_mu", "mu_to_seconds"
|
||||
}
|
||||
|
||||
|
||||
def _is_ref_transparent(dependencies, expr):
|
||||
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
|
||||
return True
|
||||
elif isinstance(expr, ast.Name):
|
||||
dependencies.add(expr.id)
|
||||
return True
|
||||
elif isinstance(expr, ast.UnaryOp):
|
||||
return _is_ref_transparent(dependencies, expr.operand)
|
||||
elif isinstance(expr, ast.BinOp):
|
||||
return (_is_ref_transparent(dependencies, expr.left)
|
||||
and _is_ref_transparent(dependencies, expr.right))
|
||||
elif isinstance(expr, ast.BoolOp):
|
||||
return all(_is_ref_transparent(dependencies, v) for v in expr.values)
|
||||
elif isinstance(expr, ast.Call):
|
||||
return (expr.func.id in _replaceable_funcs and
|
||||
all(_is_ref_transparent(dependencies, arg)
|
||||
for arg in expr.args))
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_ref_transparent(expr):
|
||||
dependencies = set()
|
||||
if _is_ref_transparent(dependencies, expr):
|
||||
return True, dependencies
|
||||
else:
|
||||
return False, None
|
||||
|
||||
|
||||
class _NodeCounter(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def generic_visit(self, node):
|
||||
self.count += 1
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
|
||||
def count_all_nodes(node):
|
||||
nc = _NodeCounter()
|
||||
nc.visit(node)
|
||||
return nc.count
|
@ -1,600 +0,0 @@
|
||||
import sys
|
||||
import ast
|
||||
|
||||
|
||||
# Large float and imaginary literals get turned into infinities in the AST.
|
||||
# We unparse those infinities to INFSTR.
|
||||
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
|
||||
|
||||
|
||||
def _interleave(inter, f, seq):
|
||||
"""Call f on each item in seq, calling inter() in between.
|
||||
"""
|
||||
seq = iter(seq)
|
||||
try:
|
||||
f(next(seq))
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
for x in seq:
|
||||
inter()
|
||||
f(x)
|
||||
|
||||
|
||||
class _Unparser:
|
||||
"""Methods in this class recursively traverse an AST and
|
||||
output source code for the abstract syntax; original formatting
|
||||
is disregarded. """
|
||||
|
||||
def __init__(self, tree):
|
||||
"""Print the source for tree to the "result" string."""
|
||||
self.result = ""
|
||||
self._indent = 0
|
||||
self.dispatch(tree)
|
||||
self.result += "\n"
|
||||
|
||||
def fill(self, text=""):
|
||||
"Indent a piece of text, according to the current indentation level"
|
||||
self.result += "\n"+" "*self._indent + text
|
||||
|
||||
def write(self, text):
|
||||
"Append a piece of text to the current line."
|
||||
self.result += text
|
||||
|
||||
def enter(self):
|
||||
"Print ':', and increase the indentation."
|
||||
self.write(":")
|
||||
self._indent += 1
|
||||
|
||||
def leave(self):
|
||||
"Decrease the indentation level."
|
||||
self._indent -= 1
|
||||
|
||||
def dispatch(self, tree):
|
||||
"Dispatcher function, dispatching tree type T to method _T."
|
||||
if isinstance(tree, list):
|
||||
for t in tree:
|
||||
self.dispatch(t)
|
||||
return
|
||||
meth = getattr(self, "_"+tree.__class__.__name__)
|
||||
meth(tree)
|
||||
|
||||
# Unparsing methods
|
||||
#
|
||||
# There should be one method per concrete grammar type
|
||||
# Constructors should be grouped by sum type. Ideally,
|
||||
# this would follow the order in the grammar, but
|
||||
# currently doesn't.
|
||||
|
||||
def _Module(self, tree):
|
||||
for stmt in tree.body:
|
||||
self.dispatch(stmt)
|
||||
|
||||
# stmt
|
||||
def _Expr(self, tree):
|
||||
self.fill()
|
||||
self.dispatch(tree.value)
|
||||
|
||||
def _Import(self, t):
|
||||
self.fill("import ")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.names)
|
||||
|
||||
def _ImportFrom(self, t):
|
||||
self.fill("from ")
|
||||
self.write("." * t.level)
|
||||
if t.module:
|
||||
self.write(t.module)
|
||||
self.write(" import ")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.names)
|
||||
|
||||
def _Assign(self, t):
|
||||
self.fill()
|
||||
for target in t.targets:
|
||||
self.dispatch(target)
|
||||
self.write(" = ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _AugAssign(self, t):
|
||||
self.fill()
|
||||
self.dispatch(t.target)
|
||||
self.write(" "+self.binop[t.op.__class__.__name__]+"= ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Return(self, t):
|
||||
self.fill("return")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Pass(self, t):
|
||||
self.fill("pass")
|
||||
|
||||
def _Break(self, t):
|
||||
self.fill("break")
|
||||
|
||||
def _Continue(self, t):
|
||||
self.fill("continue")
|
||||
|
||||
def _Delete(self, t):
|
||||
self.fill("del ")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.targets)
|
||||
|
||||
def _Assert(self, t):
|
||||
self.fill("assert ")
|
||||
self.dispatch(t.test)
|
||||
if t.msg:
|
||||
self.write(", ")
|
||||
self.dispatch(t.msg)
|
||||
|
||||
def _Global(self, t):
|
||||
self.fill("global ")
|
||||
_interleave(lambda: self.write(", "), self.write, t.names)
|
||||
|
||||
def _Nonlocal(self, t):
|
||||
self.fill("nonlocal ")
|
||||
_interleave(lambda: self.write(", "), self.write, t.names)
|
||||
|
||||
def _Yield(self, t):
|
||||
self.write("(")
|
||||
self.write("yield")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
self.write(")")
|
||||
|
||||
def _YieldFrom(self, t):
|
||||
self.write("(")
|
||||
self.write("yield from")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
self.write(")")
|
||||
|
||||
def _Raise(self, t):
|
||||
self.fill("raise")
|
||||
if not t.exc:
|
||||
assert not t.cause
|
||||
return
|
||||
self.write(" ")
|
||||
self.dispatch(t.exc)
|
||||
if t.cause:
|
||||
self.write(" from ")
|
||||
self.dispatch(t.cause)
|
||||
|
||||
def _Try(self, t):
|
||||
self.fill("try")
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
for ex in t.handlers:
|
||||
self.dispatch(ex)
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
if t.finalbody:
|
||||
self.fill("finally")
|
||||
self.enter()
|
||||
self.dispatch(t.finalbody)
|
||||
self.leave()
|
||||
|
||||
def _ExceptHandler(self, t):
|
||||
self.fill("except")
|
||||
if t.type:
|
||||
self.write(" ")
|
||||
self.dispatch(t.type)
|
||||
if t.name:
|
||||
self.write(" as ")
|
||||
self.write(t.name)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _ClassDef(self, t):
|
||||
self.write("\n")
|
||||
for deco in t.decorator_list:
|
||||
self.fill("@")
|
||||
self.dispatch(deco)
|
||||
self.fill("class "+t.name)
|
||||
self.write("(")
|
||||
comma = False
|
||||
for e in t.bases:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
for e in t.keywords:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
if t.starargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("*")
|
||||
self.dispatch(t.starargs)
|
||||
if t.kwargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("**")
|
||||
self.dispatch(t.kwargs)
|
||||
self.write(")")
|
||||
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _FunctionDef(self, t):
|
||||
self.write("\n")
|
||||
for deco in t.decorator_list:
|
||||
self.fill("@")
|
||||
self.dispatch(deco)
|
||||
self.fill("def "+t.name + "(")
|
||||
self.dispatch(t.args)
|
||||
self.write(")")
|
||||
if t.returns:
|
||||
self.write(" -> ")
|
||||
self.dispatch(t.returns)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _For(self, t):
|
||||
self.fill("for ")
|
||||
self.dispatch(t.target)
|
||||
self.write(" in ")
|
||||
self.dispatch(t.iter)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _If(self, t):
|
||||
self.fill("if ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
# collapse nested ifs into equivalent elifs.
|
||||
while (t.orelse and len(t.orelse) == 1 and
|
||||
isinstance(t.orelse[0], ast.If)):
|
||||
t = t.orelse[0]
|
||||
self.fill("elif ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
# final else
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _While(self, t):
|
||||
self.fill("while ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _With(self, t):
|
||||
self.fill("with ")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.items)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
# expr
|
||||
def _Bytes(self, t):
|
||||
self.write(repr(t.s))
|
||||
|
||||
def _Str(self, tree):
|
||||
self.write(repr(tree.s))
|
||||
|
||||
def _Name(self, t):
|
||||
self.write(t.id)
|
||||
|
||||
def _NameConstant(self, t):
|
||||
self.write(repr(t.value))
|
||||
|
||||
def _Num(self, t):
|
||||
# Substitute overflowing decimal literal for AST infinities.
|
||||
self.write(repr(t.n).replace("inf", INFSTR))
|
||||
|
||||
def _List(self, t):
|
||||
self.write("[")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write("]")
|
||||
|
||||
def _ListComp(self, t):
|
||||
self.write("[")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("]")
|
||||
|
||||
def _GeneratorExp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write(")")
|
||||
|
||||
def _SetComp(self, t):
|
||||
self.write("{")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("}")
|
||||
|
||||
def _DictComp(self, t):
|
||||
self.write("{")
|
||||
self.dispatch(t.key)
|
||||
self.write(": ")
|
||||
self.dispatch(t.value)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("}")
|
||||
|
||||
def _comprehension(self, t):
|
||||
self.write(" for ")
|
||||
self.dispatch(t.target)
|
||||
self.write(" in ")
|
||||
self.dispatch(t.iter)
|
||||
for if_clause in t.ifs:
|
||||
self.write(" if ")
|
||||
self.dispatch(if_clause)
|
||||
|
||||
def _IfExp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.body)
|
||||
self.write(" if ")
|
||||
self.dispatch(t.test)
|
||||
self.write(" else ")
|
||||
self.dispatch(t.orelse)
|
||||
self.write(")")
|
||||
|
||||
def _Set(self, t):
|
||||
assert(t.elts) # should be at least one element
|
||||
self.write("{")
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write("}")
|
||||
|
||||
def _Dict(self, t):
|
||||
self.write("{")
|
||||
|
||||
def write_pair(pair):
|
||||
(k, v) = pair
|
||||
self.dispatch(k)
|
||||
self.write(": ")
|
||||
self.dispatch(v)
|
||||
_interleave(lambda: self.write(", "), write_pair,
|
||||
zip(t.keys, t.values))
|
||||
self.write("}")
|
||||
|
||||
def _Tuple(self, t):
|
||||
self.write("(")
|
||||
if len(t.elts) == 1:
|
||||
(elt,) = t.elts
|
||||
self.dispatch(elt)
|
||||
self.write(",")
|
||||
else:
|
||||
_interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write(")")
|
||||
|
||||
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
|
||||
|
||||
def _UnaryOp(self, t):
|
||||
self.write("(")
|
||||
self.write(self.unop[t.op.__class__.__name__])
|
||||
self.write(" ")
|
||||
self.dispatch(t.operand)
|
||||
self.write(")")
|
||||
|
||||
binop = {"Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%",
|
||||
"LShift": "<<", "RShift": ">>",
|
||||
"BitOr": "|", "BitXor": "^", "BitAnd": "&",
|
||||
"FloorDiv": "//", "Pow": "**"}
|
||||
|
||||
def _BinOp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.left)
|
||||
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
|
||||
self.dispatch(t.right)
|
||||
self.write(")")
|
||||
|
||||
cmpops = {"Eq": "==", "NotEq": "!=",
|
||||
"Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=",
|
||||
"Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"}
|
||||
|
||||
def _Compare(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.left)
|
||||
for o, e in zip(t.ops, t.comparators):
|
||||
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
|
||||
self.dispatch(e)
|
||||
self.write(")")
|
||||
|
||||
boolops = {ast.And: "and", ast.Or: "or"}
|
||||
|
||||
def _BoolOp(self, t):
|
||||
self.write("(")
|
||||
s = " %s " % self.boolops[t.op.__class__]
|
||||
_interleave(lambda: self.write(s), self.dispatch, t.values)
|
||||
self.write(")")
|
||||
|
||||
def _Attribute(self, t):
|
||||
self.dispatch(t.value)
|
||||
# Special case: 3.__abs__() is a syntax error, so if t.value
|
||||
# is an integer literal then we need to either parenthesize
|
||||
# it or add an extra space to get 3 .__abs__().
|
||||
if isinstance(t.value, ast.Num) and isinstance(t.value.n, int):
|
||||
self.write(" ")
|
||||
self.write(".")
|
||||
self.write(t.attr)
|
||||
|
||||
def _Call(self, t):
|
||||
self.dispatch(t.func)
|
||||
self.write("(")
|
||||
comma = False
|
||||
for e in t.args:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
for e in t.keywords:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
if t.starargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("*")
|
||||
self.dispatch(t.starargs)
|
||||
if t.kwargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("**")
|
||||
self.dispatch(t.kwargs)
|
||||
self.write(")")
|
||||
|
||||
def _Subscript(self, t):
|
||||
self.dispatch(t.value)
|
||||
self.write("[")
|
||||
self.dispatch(t.slice)
|
||||
self.write("]")
|
||||
|
||||
def _Starred(self, t):
|
||||
self.write("*")
|
||||
self.dispatch(t.value)
|
||||
|
||||
# slice
|
||||
def _Ellipsis(self, t):
|
||||
self.write("...")
|
||||
|
||||
def _Index(self, t):
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Slice(self, t):
|
||||
if t.lower:
|
||||
self.dispatch(t.lower)
|
||||
self.write(":")
|
||||
if t.upper:
|
||||
self.dispatch(t.upper)
|
||||
if t.step:
|
||||
self.write(":")
|
||||
self.dispatch(t.step)
|
||||
|
||||
def _ExtSlice(self, t):
|
||||
_interleave(lambda: self.write(', '), self.dispatch, t.dims)
|
||||
|
||||
# argument
|
||||
def _arg(self, t):
|
||||
self.write(t.arg)
|
||||
if t.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.annotation)
|
||||
|
||||
# others
|
||||
def _arguments(self, t):
|
||||
first = True
|
||||
# normal arguments
|
||||
defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults
|
||||
for a, d in zip(t.args, defaults):
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.dispatch(a)
|
||||
if d:
|
||||
self.write("=")
|
||||
self.dispatch(d)
|
||||
|
||||
# varargs, or bare '*' if no varargs but keyword-only arguments present
|
||||
if t.vararg or t.kwonlyargs:
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.write("*")
|
||||
if t.vararg:
|
||||
self.write(t.vararg.arg)
|
||||
if t.vararg.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.vararg.annotation)
|
||||
|
||||
# keyword-only arguments
|
||||
if t.kwonlyargs:
|
||||
for a, d in zip(t.kwonlyargs, t.kw_defaults):
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.dispatch(a),
|
||||
if d:
|
||||
self.write("=")
|
||||
self.dispatch(d)
|
||||
|
||||
# kwargs
|
||||
if t.kwarg:
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.write("**"+t.kwarg.arg)
|
||||
if t.kwarg.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.kwarg.annotation)
|
||||
|
||||
def _keyword(self, t):
|
||||
self.write(t.arg)
|
||||
self.write("=")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Lambda(self, t):
|
||||
self.write("(")
|
||||
self.write("lambda ")
|
||||
self.dispatch(t.args)
|
||||
self.write(": ")
|
||||
self.dispatch(t.body)
|
||||
self.write(")")
|
||||
|
||||
def _alias(self, t):
|
||||
self.write(t.name)
|
||||
if t.asname:
|
||||
self.write(" as "+t.asname)
|
||||
|
||||
def _withitem(self, t):
|
||||
self.dispatch(t.context_expr)
|
||||
if t.optional_vars:
|
||||
self.write(" as ")
|
||||
self.dispatch(t.optional_vars)
|
||||
|
||||
|
||||
def unparse(tree):
|
||||
unparser = _Unparser(tree)
|
||||
return unparser.result
|
Loading…
Reference in New Issue
Block a user