compiler/ir_ast_body: refactor and add statement visitor

This commit is contained in:
Sebastien Bourdeauducq 2014-08-17 21:46:11 +08:00
parent d41d06835c
commit 30ef6119e6
2 changed files with 160 additions and 86 deletions

View File

@ -2,91 +2,165 @@ import ast
from artiq.compiler import ir_values from artiq.compiler import ir_values
_ast_unops = { class Visitor:
ast.Invert: ir_values.operators.inv,
ast.Not: ir_values.operators.not_,
ast.UAdd: ir_values.operators.pos,
ast.USub: ir_values.operators.neg
}
_ast_binops = {
ast.Add: ir_values.operators.add,
ast.Sub: ir_values.operators.sub,
ast.Mult: ir_values.operators.mul,
ast.Div: ir_values.operators.truediv,
ast.FloorDiv: ir_values.operators.floordiv,
ast.Mod: ir_values.operators.mod,
ast.Pow: ir_values.operators.pow,
ast.LShift: ir_values.operators.lshift,
ast.RShift: ir_values.operators.rshift,
ast.BitOr: ir_values.operators.or_,
ast.BitXor: ir_values.operators.xor,
ast.BitAnd: ir_values.operators.and_
}
_ast_cmps = {
ast.Eq: ir_values.operators.eq,
ast.NotEq: ir_values.operators.ne,
ast.Lt: ir_values.operators.lt,
ast.LtE: ir_values.operators.le,
ast.Gt: ir_values.operators.gt,
ast.GtE: ir_values.operators.ge
}
_ast_unfuns = {
"bool": ir_values.operators.bool,
"int": ir_values.operators.int,
"int64": ir_values.operators.int64,
"round": ir_values.operators.round,
"round64": ir_values.operators.round64,
}
class ExpressionVisitor:
def __init__(self, builder, ns): def __init__(self, builder, ns):
self.builder = builder self.builder = builder
self.ns = ns self.ns = ns
def visit(self, node): # builder can be None for visit_expression
if isinstance(node, ast.Name): def visit_expression(self, node):
return self.ns.load(self.builder, node.id) method = "_visit_expr_" + node.__class__.__name__
elif isinstance(node, ast.NameConstant): try:
v = node.value visitor = getattr(self, method)
if isinstance(v, bool): except AttributeError:
r = ir_values.VBool() raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__))
else: return visitor(node)
raise NotImplementedError
if self.builder is not None: def _visit_expr_Name(self, node):
r.create_constant(v) return self.ns.load(self.builder, node.id)
return r
elif isinstance(node, ast.Num): def _visit_expr_NameConstant(self, node):
n = node.n v = node.value
if isinstance(n, int): if isinstance(v, bool):
if abs(n) < 2**31: r = ir_values.VBool()
r = ir_values.VInt()
else:
r = ir_values.VInt(64)
else:
raise NotImplementedError
if self.builder is not None:
r.create_constant(n)
return r
elif isinstance(node, ast.UnaryOp):
return _ast_unops[type(node.op)](self.visit(node.operand), self.builder)
elif isinstance(node, ast.BinOp):
return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder)
elif isinstance(node, ast.Compare):
comparisons = []
old_comparator = self.visit(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit(comparator_a)
comparison = _ast_cmps[type(op)](old_comparator, comparator)
comparisons.append(comparison)
old_comparator = comparator
r = comparisons[0]
for comparison in comparisons[1:]:
r = ir_values.operators.and_(r, comparison)
return r
elif isinstance(node, ast.Call):
return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder)
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None:
r.create_constant(v)
return r
def _visit_expr_Num(self, node):
n = node.n
if isinstance(n, int):
if abs(n) < 2**31:
r = ir_values.VInt()
else:
r = ir_values.VInt(64)
else:
raise NotImplementedError
if self.builder is not None:
r.create_constant(n)
return r
def _visit_expr_UnaryOp(self, node):
ast_unops = {
ast.Invert: ir_values.operators.inv,
ast.Not: ir_values.operators.not_,
ast.UAdd: ir_values.operators.pos,
ast.USub: ir_values.operators.neg
}
return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder)
def _visit_expr_BinOp(self, node):
ast_binops = {
ast.Add: ir_values.operators.add,
ast.Sub: ir_values.operators.sub,
ast.Mult: ir_values.operators.mul,
ast.Div: ir_values.operators.truediv,
ast.FloorDiv: ir_values.operators.floordiv,
ast.Mod: ir_values.operators.mod,
ast.Pow: ir_values.operators.pow,
ast.LShift: ir_values.operators.lshift,
ast.RShift: ir_values.operators.rshift,
ast.BitOr: ir_values.operators.or_,
ast.BitXor: ir_values.operators.xor,
ast.BitAnd: ir_values.operators.and_
}
return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder)
def _visit_expr_Compare(self, node):
ast_cmps = {
ast.Eq: ir_values.operators.eq,
ast.NotEq: ir_values.operators.ne,
ast.Lt: ir_values.operators.lt,
ast.LtE: ir_values.operators.le,
ast.Gt: ir_values.operators.gt,
ast.GtE: ir_values.operators.ge
}
comparisons = []
old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a)
comparison = ast_cmps[type(op)](old_comparator, comparator)
comparisons.append(comparison)
old_comparator = comparator
r = comparisons[0]
for comparison in comparisons[1:]:
r = ir_values.operators.and_(r, comparison)
return r
def _visit_expr_Call(self, node):
ast_unfuns = {
"bool": ir_values.operators.bool,
"int": ir_values.operators.int,
"int64": ir_values.operators.int64,
"round": ir_values.operators.round,
"round64": ir_values.operators.round64,
}
return ast_unfuns[node.func.id](self.visit_expression(node.args[0]), self.builder)
def visit_statements(self, stmts):
for node in stmts:
method = "_visit_stmt_" + node.__class__.__name__
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__))
visitor(node)
def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value)
for target in node.targets:
if isinstance(target, ast.Name):
self.ns.store(self.builder, val, target.id)
else:
raise NotImplementedError
def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
if isinstance(node.target, ast.Name):
self.ns.store(self.builder, val, node.target.id)
else:
raise NotImplementedError
def _visit_stmt_Expr(self, node):
self.visit_expression(node.value)
def _visit_stmt_If(self, node):
function = self.builder.basic_block.function
then_block = function.append_basic_block("i_then")
else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
self.builder.cbranch(condition.llvm_value, then_block, else_block)
self.builder.position_at_end(then_block)
self.visit_statements(node.body)
self.builder.branch(merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)
def _visit_stmt_While(self, node):
function = self.builder.basic_block.function
body_block = function.append_basic_block("w_body")
else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge")
condition = self.visit_expression(node.test)
self.builder.cbranch(condition, body_block, else_block)
self.builder.position_at_end(body_block)
self.visit_statements(node.body)
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
self.builder.cbranch(condition.llvm_value, body_block, merge_block)
self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
self.builder.branch(merge_block)
self.builder.position_at_end(merge_block)

View File

@ -2,7 +2,7 @@ import ast
from operator import itemgetter from operator import itemgetter
from copy import deepcopy from copy import deepcopy
from artiq.compiler.ir_ast_body import ExpressionVisitor from artiq.compiler.ir_ast_body import Visitor
class _Namespace: class _Namespace:
def __init__(self, name_to_value): def __init__(self, name_to_value):
@ -13,10 +13,10 @@ class _Namespace:
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
def __init__(self, namespace): def __init__(self, namespace):
self.exprv = ExpressionVisitor(None, namespace) self.exprv = Visitor(None, namespace)
def visit_Assign(self, node): def visit_Assign(self, node):
val = self.exprv.visit(node.value) val = self.exprv.visit_expression(node.value)
n2v = self.exprv.ns.name_to_value n2v = self.exprv.ns.name_to_value
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
@ -28,7 +28,7 @@ class _TypeScanner(ast.NodeVisitor):
raise NotImplementedError raise NotImplementedError
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
n2v = self.exprv.ns.name_to_value n2v = self.exprv.ns.name_to_value
target = node.target target = node.target
if isinstance(target, ast.Name): if isinstance(target, ast.Name):