forked from M-Labs/artiq
compiler/ir_ast_body: refactor and add statement visitor
This commit is contained in:
parent
d41d06835c
commit
30ef6119e6
|
@ -2,14 +2,57 @@ import ast
|
|||
|
||||
from artiq.compiler import ir_values
|
||||
|
||||
_ast_unops = {
|
||||
class Visitor:
|
||||
def __init__(self, builder, ns):
|
||||
self.builder = builder
|
||||
self.ns = ns
|
||||
|
||||
# builder can be None for visit_expression
|
||||
def visit_expression(self, node):
|
||||
method = "_visit_expr_" + node.__class__.__name__
|
||||
try:
|
||||
visitor = getattr(self, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__))
|
||||
return visitor(node)
|
||||
|
||||
def _visit_expr_Name(self, node):
|
||||
return self.ns.load(self.builder, node.id)
|
||||
|
||||
def _visit_expr_NameConstant(self, node):
|
||||
v = node.value
|
||||
if isinstance(v, bool):
|
||||
r = ir_values.VBool()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(v)
|
||||
return r
|
||||
|
||||
def _visit_expr_Num(self, node):
|
||||
n = node.n
|
||||
if isinstance(n, int):
|
||||
if abs(n) < 2**31:
|
||||
r = ir_values.VInt()
|
||||
else:
|
||||
r = ir_values.VInt(64)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(n)
|
||||
return r
|
||||
|
||||
def _visit_expr_UnaryOp(self, node):
|
||||
ast_unops = {
|
||||
ast.Invert: ir_values.operators.inv,
|
||||
ast.Not: ir_values.operators.not_,
|
||||
ast.UAdd: ir_values.operators.pos,
|
||||
ast.USub: ir_values.operators.neg
|
||||
}
|
||||
}
|
||||
return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder)
|
||||
|
||||
_ast_binops = {
|
||||
def _visit_expr_BinOp(self, node):
|
||||
ast_binops = {
|
||||
ast.Add: ir_values.operators.add,
|
||||
ast.Sub: ir_values.operators.sub,
|
||||
ast.Mult: ir_values.operators.mul,
|
||||
|
@ -22,71 +65,102 @@ _ast_binops = {
|
|||
ast.BitOr: ir_values.operators.or_,
|
||||
ast.BitXor: ir_values.operators.xor,
|
||||
ast.BitAnd: ir_values.operators.and_
|
||||
}
|
||||
}
|
||||
return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder)
|
||||
|
||||
_ast_cmps = {
|
||||
def _visit_expr_Compare(self, node):
|
||||
ast_cmps = {
|
||||
ast.Eq: ir_values.operators.eq,
|
||||
ast.NotEq: ir_values.operators.ne,
|
||||
ast.Lt: ir_values.operators.lt,
|
||||
ast.LtE: ir_values.operators.le,
|
||||
ast.Gt: ir_values.operators.gt,
|
||||
ast.GtE: ir_values.operators.ge
|
||||
}
|
||||
|
||||
_ast_unfuns = {
|
||||
"bool": ir_values.operators.bool,
|
||||
"int": ir_values.operators.int,
|
||||
"int64": ir_values.operators.int64,
|
||||
"round": ir_values.operators.round,
|
||||
"round64": ir_values.operators.round64,
|
||||
}
|
||||
|
||||
class ExpressionVisitor:
|
||||
def __init__(self, builder, ns):
|
||||
self.builder = builder
|
||||
self.ns = ns
|
||||
|
||||
def visit(self, node):
|
||||
if isinstance(node, ast.Name):
|
||||
return self.ns.load(self.builder, node.id)
|
||||
elif isinstance(node, ast.NameConstant):
|
||||
v = node.value
|
||||
if isinstance(v, bool):
|
||||
r = ir_values.VBool()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(v)
|
||||
return r
|
||||
elif isinstance(node, ast.Num):
|
||||
n = node.n
|
||||
if isinstance(n, int):
|
||||
if abs(n) < 2**31:
|
||||
r = ir_values.VInt()
|
||||
else:
|
||||
r = ir_values.VInt(64)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if self.builder is not None:
|
||||
r.create_constant(n)
|
||||
return r
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
return _ast_unops[type(node.op)](self.visit(node.operand), self.builder)
|
||||
elif isinstance(node, ast.BinOp):
|
||||
return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder)
|
||||
elif isinstance(node, ast.Compare):
|
||||
}
|
||||
comparisons = []
|
||||
old_comparator = self.visit(node.left)
|
||||
old_comparator = self.visit_expression(node.left)
|
||||
for op, comparator_a in zip(node.ops, node.comparators):
|
||||
comparator = self.visit(comparator_a)
|
||||
comparison = _ast_cmps[type(op)](old_comparator, comparator)
|
||||
comparator = self.visit_expression(comparator_a)
|
||||
comparison = ast_cmps[type(op)](old_comparator, comparator)
|
||||
comparisons.append(comparison)
|
||||
old_comparator = comparator
|
||||
r = comparisons[0]
|
||||
for comparison in comparisons[1:]:
|
||||
r = ir_values.operators.and_(r, comparison)
|
||||
return r
|
||||
elif isinstance(node, ast.Call):
|
||||
return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder)
|
||||
|
||||
def _visit_expr_Call(self, node):
|
||||
ast_unfuns = {
|
||||
"bool": ir_values.operators.bool,
|
||||
"int": ir_values.operators.int,
|
||||
"int64": ir_values.operators.int64,
|
||||
"round": ir_values.operators.round,
|
||||
"round64": ir_values.operators.round64,
|
||||
}
|
||||
return ast_unfuns[node.func.id](self.visit_expression(node.args[0]), self.builder)
|
||||
|
||||
def visit_statements(self, stmts):
|
||||
for node in stmts:
|
||||
method = "_visit_stmt_" + node.__class__.__name__
|
||||
try:
|
||||
visitor = getattr(self, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__))
|
||||
visitor(node)
|
||||
|
||||
def _visit_stmt_Assign(self, node):
|
||||
val = self.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.ns.store(self.builder, val, target.id)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _visit_stmt_AugAssign(self, node):
|
||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.ns.store(self.builder, val, node.target.id)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _visit_stmt_Expr(self, node):
|
||||
self.visit_expression(node.value)
|
||||
|
||||
def _visit_stmt_If(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
then_block = function.append_basic_block("i_then")
|
||||
else_block = function.append_basic_block("i_else")
|
||||
merge_block = function.append_basic_block("i_merge")
|
||||
|
||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(condition.llvm_value, then_block, else_block)
|
||||
|
||||
self.builder.position_at_end(then_block)
|
||||
self.visit_statements(node.body)
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _visit_stmt_While(self, node):
|
||||
function = self.builder.basic_block.function
|
||||
body_block = function.append_basic_block("w_body")
|
||||
else_block = function.append_basic_block("w_else")
|
||||
merge_block = function.append_basic_block("w_merge")
|
||||
|
||||
condition = self.visit_expression(node.test)
|
||||
self.builder.cbranch(condition, body_block, else_block)
|
||||
|
||||
self.builder.position_at_end(body_block)
|
||||
self.visit_statements(node.body)
|
||||
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(condition.llvm_value, body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
|
|
@ -2,7 +2,7 @@ import ast
|
|||
from operator import itemgetter
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.compiler.ir_ast_body import ExpressionVisitor
|
||||
from artiq.compiler.ir_ast_body import Visitor
|
||||
|
||||
class _Namespace:
|
||||
def __init__(self, name_to_value):
|
||||
|
@ -13,10 +13,10 @@ class _Namespace:
|
|||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
def __init__(self, namespace):
|
||||
self.exprv = ExpressionVisitor(None, namespace)
|
||||
self.exprv = Visitor(None, namespace)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit(node.value)
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
|
@ -28,7 +28,7 @@ class _TypeScanner(ast.NodeVisitor):
|
|||
raise NotImplementedError
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
target = node.target
|
||||
if isinstance(target, ast.Name):
|
||||
|
|
Loading…
Reference in New Issue