forked from M-Labs/artiq
compiler/ir_ast_body: refactor and add statement visitor
This commit is contained in:
parent
d41d06835c
commit
30ef6119e6
|
@ -2,14 +2,57 @@ import ast
|
||||||
|
|
||||||
from artiq.compiler import ir_values
|
from artiq.compiler import ir_values
|
||||||
|
|
||||||
_ast_unops = {
|
class Visitor:
|
||||||
|
def __init__(self, builder, ns):
|
||||||
|
self.builder = builder
|
||||||
|
self.ns = ns
|
||||||
|
|
||||||
|
# builder can be None for visit_expression
|
||||||
|
def visit_expression(self, node):
|
||||||
|
method = "_visit_expr_" + node.__class__.__name__
|
||||||
|
try:
|
||||||
|
visitor = getattr(self, method)
|
||||||
|
except AttributeError:
|
||||||
|
raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__))
|
||||||
|
return visitor(node)
|
||||||
|
|
||||||
|
def _visit_expr_Name(self, node):
|
||||||
|
return self.ns.load(self.builder, node.id)
|
||||||
|
|
||||||
|
def _visit_expr_NameConstant(self, node):
|
||||||
|
v = node.value
|
||||||
|
if isinstance(v, bool):
|
||||||
|
r = ir_values.VBool()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if self.builder is not None:
|
||||||
|
r.create_constant(v)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def _visit_expr_Num(self, node):
|
||||||
|
n = node.n
|
||||||
|
if isinstance(n, int):
|
||||||
|
if abs(n) < 2**31:
|
||||||
|
r = ir_values.VInt()
|
||||||
|
else:
|
||||||
|
r = ir_values.VInt(64)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if self.builder is not None:
|
||||||
|
r.create_constant(n)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def _visit_expr_UnaryOp(self, node):
|
||||||
|
ast_unops = {
|
||||||
ast.Invert: ir_values.operators.inv,
|
ast.Invert: ir_values.operators.inv,
|
||||||
ast.Not: ir_values.operators.not_,
|
ast.Not: ir_values.operators.not_,
|
||||||
ast.UAdd: ir_values.operators.pos,
|
ast.UAdd: ir_values.operators.pos,
|
||||||
ast.USub: ir_values.operators.neg
|
ast.USub: ir_values.operators.neg
|
||||||
}
|
}
|
||||||
|
return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder)
|
||||||
|
|
||||||
_ast_binops = {
|
def _visit_expr_BinOp(self, node):
|
||||||
|
ast_binops = {
|
||||||
ast.Add: ir_values.operators.add,
|
ast.Add: ir_values.operators.add,
|
||||||
ast.Sub: ir_values.operators.sub,
|
ast.Sub: ir_values.operators.sub,
|
||||||
ast.Mult: ir_values.operators.mul,
|
ast.Mult: ir_values.operators.mul,
|
||||||
|
@ -22,71 +65,102 @@ _ast_binops = {
|
||||||
ast.BitOr: ir_values.operators.or_,
|
ast.BitOr: ir_values.operators.or_,
|
||||||
ast.BitXor: ir_values.operators.xor,
|
ast.BitXor: ir_values.operators.xor,
|
||||||
ast.BitAnd: ir_values.operators.and_
|
ast.BitAnd: ir_values.operators.and_
|
||||||
}
|
}
|
||||||
|
return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder)
|
||||||
|
|
||||||
_ast_cmps = {
|
def _visit_expr_Compare(self, node):
|
||||||
|
ast_cmps = {
|
||||||
ast.Eq: ir_values.operators.eq,
|
ast.Eq: ir_values.operators.eq,
|
||||||
ast.NotEq: ir_values.operators.ne,
|
ast.NotEq: ir_values.operators.ne,
|
||||||
ast.Lt: ir_values.operators.lt,
|
ast.Lt: ir_values.operators.lt,
|
||||||
ast.LtE: ir_values.operators.le,
|
ast.LtE: ir_values.operators.le,
|
||||||
ast.Gt: ir_values.operators.gt,
|
ast.Gt: ir_values.operators.gt,
|
||||||
ast.GtE: ir_values.operators.ge
|
ast.GtE: ir_values.operators.ge
|
||||||
}
|
}
|
||||||
|
|
||||||
_ast_unfuns = {
|
|
||||||
"bool": ir_values.operators.bool,
|
|
||||||
"int": ir_values.operators.int,
|
|
||||||
"int64": ir_values.operators.int64,
|
|
||||||
"round": ir_values.operators.round,
|
|
||||||
"round64": ir_values.operators.round64,
|
|
||||||
}
|
|
||||||
|
|
||||||
class ExpressionVisitor:
|
|
||||||
def __init__(self, builder, ns):
|
|
||||||
self.builder = builder
|
|
||||||
self.ns = ns
|
|
||||||
|
|
||||||
def visit(self, node):
|
|
||||||
if isinstance(node, ast.Name):
|
|
||||||
return self.ns.load(self.builder, node.id)
|
|
||||||
elif isinstance(node, ast.NameConstant):
|
|
||||||
v = node.value
|
|
||||||
if isinstance(v, bool):
|
|
||||||
r = ir_values.VBool()
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
if self.builder is not None:
|
|
||||||
r.create_constant(v)
|
|
||||||
return r
|
|
||||||
elif isinstance(node, ast.Num):
|
|
||||||
n = node.n
|
|
||||||
if isinstance(n, int):
|
|
||||||
if abs(n) < 2**31:
|
|
||||||
r = ir_values.VInt()
|
|
||||||
else:
|
|
||||||
r = ir_values.VInt(64)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
if self.builder is not None:
|
|
||||||
r.create_constant(n)
|
|
||||||
return r
|
|
||||||
elif isinstance(node, ast.UnaryOp):
|
|
||||||
return _ast_unops[type(node.op)](self.visit(node.operand), self.builder)
|
|
||||||
elif isinstance(node, ast.BinOp):
|
|
||||||
return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder)
|
|
||||||
elif isinstance(node, ast.Compare):
|
|
||||||
comparisons = []
|
comparisons = []
|
||||||
old_comparator = self.visit(node.left)
|
old_comparator = self.visit_expression(node.left)
|
||||||
for op, comparator_a in zip(node.ops, node.comparators):
|
for op, comparator_a in zip(node.ops, node.comparators):
|
||||||
comparator = self.visit(comparator_a)
|
comparator = self.visit_expression(comparator_a)
|
||||||
comparison = _ast_cmps[type(op)](old_comparator, comparator)
|
comparison = ast_cmps[type(op)](old_comparator, comparator)
|
||||||
comparisons.append(comparison)
|
comparisons.append(comparison)
|
||||||
old_comparator = comparator
|
old_comparator = comparator
|
||||||
r = comparisons[0]
|
r = comparisons[0]
|
||||||
for comparison in comparisons[1:]:
|
for comparison in comparisons[1:]:
|
||||||
r = ir_values.operators.and_(r, comparison)
|
r = ir_values.operators.and_(r, comparison)
|
||||||
return r
|
return r
|
||||||
elif isinstance(node, ast.Call):
|
|
||||||
return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder)
|
def _visit_expr_Call(self, node):
|
||||||
|
ast_unfuns = {
|
||||||
|
"bool": ir_values.operators.bool,
|
||||||
|
"int": ir_values.operators.int,
|
||||||
|
"int64": ir_values.operators.int64,
|
||||||
|
"round": ir_values.operators.round,
|
||||||
|
"round64": ir_values.operators.round64,
|
||||||
|
}
|
||||||
|
return ast_unfuns[node.func.id](self.visit_expression(node.args[0]), self.builder)
|
||||||
|
|
||||||
|
def visit_statements(self, stmts):
|
||||||
|
for node in stmts:
|
||||||
|
method = "_visit_stmt_" + node.__class__.__name__
|
||||||
|
try:
|
||||||
|
visitor = getattr(self, method)
|
||||||
|
except AttributeError:
|
||||||
|
raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__))
|
||||||
|
visitor(node)
|
||||||
|
|
||||||
|
def _visit_stmt_Assign(self, node):
|
||||||
|
val = self.visit_expression(node.value)
|
||||||
|
for target in node.targets:
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
|
self.ns.store(self.builder, val, target.id)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _visit_stmt_AugAssign(self, node):
|
||||||
|
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||||
|
if isinstance(node.target, ast.Name):
|
||||||
|
self.ns.store(self.builder, val, node.target.id)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _visit_stmt_Expr(self, node):
|
||||||
|
self.visit_expression(node.value)
|
||||||
|
|
||||||
|
def _visit_stmt_If(self, node):
|
||||||
|
function = self.builder.basic_block.function
|
||||||
|
then_block = function.append_basic_block("i_then")
|
||||||
|
else_block = function.append_basic_block("i_else")
|
||||||
|
merge_block = function.append_basic_block("i_merge")
|
||||||
|
|
||||||
|
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||||
|
self.builder.cbranch(condition.llvm_value, then_block, else_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(then_block)
|
||||||
|
self.visit_statements(node.body)
|
||||||
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(else_block)
|
||||||
|
self.visit_statements(node.orelse)
|
||||||
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(merge_block)
|
||||||
|
|
||||||
|
def _visit_stmt_While(self, node):
|
||||||
|
function = self.builder.basic_block.function
|
||||||
|
body_block = function.append_basic_block("w_body")
|
||||||
|
else_block = function.append_basic_block("w_else")
|
||||||
|
merge_block = function.append_basic_block("w_merge")
|
||||||
|
|
||||||
|
condition = self.visit_expression(node.test)
|
||||||
|
self.builder.cbranch(condition, body_block, else_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(body_block)
|
||||||
|
self.visit_statements(node.body)
|
||||||
|
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder)
|
||||||
|
self.builder.cbranch(condition.llvm_value, body_block, merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(else_block)
|
||||||
|
self.visit_statements(node.orelse)
|
||||||
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
|
self.builder.position_at_end(merge_block)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import ast
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from artiq.compiler.ir_ast_body import ExpressionVisitor
|
from artiq.compiler.ir_ast_body import Visitor
|
||||||
|
|
||||||
class _Namespace:
|
class _Namespace:
|
||||||
def __init__(self, name_to_value):
|
def __init__(self, name_to_value):
|
||||||
|
@ -13,10 +13,10 @@ class _Namespace:
|
||||||
|
|
||||||
class _TypeScanner(ast.NodeVisitor):
|
class _TypeScanner(ast.NodeVisitor):
|
||||||
def __init__(self, namespace):
|
def __init__(self, namespace):
|
||||||
self.exprv = ExpressionVisitor(None, namespace)
|
self.exprv = Visitor(None, namespace)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
val = self.exprv.visit(node.value)
|
val = self.exprv.visit_expression(node.value)
|
||||||
n2v = self.exprv.ns.name_to_value
|
n2v = self.exprv.ns.name_to_value
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
|
@ -28,7 +28,7 @@ class _TypeScanner(ast.NodeVisitor):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||||
n2v = self.exprv.ns.name_to_value
|
n2v = self.exprv.ns.name_to_value
|
||||||
target = node.target
|
target = node.target
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
|
|
Loading…
Reference in New Issue