From 30ef6119e630931105320e323ff0c783dbbb4c3a Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 17 Aug 2014 21:46:11 +0800 Subject: [PATCH] compiler/ir_ast_body: refactor and add statement visitor --- artiq/compiler/ir_ast_body.py | 238 ++++++++++++++++++++----------- artiq/compiler/ir_infer_types.py | 8 +- 2 files changed, 160 insertions(+), 86 deletions(-) diff --git a/artiq/compiler/ir_ast_body.py b/artiq/compiler/ir_ast_body.py index 58f18f1c8..7ac1c8846 100644 --- a/artiq/compiler/ir_ast_body.py +++ b/artiq/compiler/ir_ast_body.py @@ -2,91 +2,165 @@ import ast from artiq.compiler import ir_values -_ast_unops = { - ast.Invert: ir_values.operators.inv, - ast.Not: ir_values.operators.not_, - ast.UAdd: ir_values.operators.pos, - ast.USub: ir_values.operators.neg -} - -_ast_binops = { - ast.Add: ir_values.operators.add, - ast.Sub: ir_values.operators.sub, - ast.Mult: ir_values.operators.mul, - ast.Div: ir_values.operators.truediv, - ast.FloorDiv: ir_values.operators.floordiv, - ast.Mod: ir_values.operators.mod, - ast.Pow: ir_values.operators.pow, - ast.LShift: ir_values.operators.lshift, - ast.RShift: ir_values.operators.rshift, - ast.BitOr: ir_values.operators.or_, - ast.BitXor: ir_values.operators.xor, - ast.BitAnd: ir_values.operators.and_ -} - -_ast_cmps = { - ast.Eq: ir_values.operators.eq, - ast.NotEq: ir_values.operators.ne, - ast.Lt: ir_values.operators.lt, - ast.LtE: ir_values.operators.le, - ast.Gt: ir_values.operators.gt, - ast.GtE: ir_values.operators.ge -} - -_ast_unfuns = { - "bool": ir_values.operators.bool, - "int": ir_values.operators.int, - "int64": ir_values.operators.int64, - "round": ir_values.operators.round, - "round64": ir_values.operators.round64, -} - -class ExpressionVisitor: +class Visitor: def __init__(self, builder, ns): self.builder = builder self.ns = ns - def visit(self, node): - if isinstance(node, ast.Name): - return self.ns.load(self.builder, node.id) - elif isinstance(node, ast.NameConstant): - v = node.value - if isinstance(v, bool): - r = ir_values.VBool() - else: - raise NotImplementedError - if self.builder is not None: - r.create_constant(v) - return r - elif isinstance(node, ast.Num): - n = node.n - if isinstance(n, int): - if abs(n) < 2**31: - r = ir_values.VInt() - else: - r = ir_values.VInt(64) - else: - raise NotImplementedError - if self.builder is not None: - r.create_constant(n) - return r - elif isinstance(node, ast.UnaryOp): - return _ast_unops[type(node.op)](self.visit(node.operand), self.builder) - elif isinstance(node, ast.BinOp): - return _ast_binops[type(node.op)](self.visit(node.left), self.visit(node.right), self.builder) - elif isinstance(node, ast.Compare): - comparisons = [] - old_comparator = self.visit(node.left) - for op, comparator_a in zip(node.ops, node.comparators): - comparator = self.visit(comparator_a) - comparison = _ast_cmps[type(op)](old_comparator, comparator) - comparisons.append(comparison) - old_comparator = comparator - r = comparisons[0] - for comparison in comparisons[1:]: - r = ir_values.operators.and_(r, comparison) - return r - elif isinstance(node, ast.Call): - return _ast_unfuns[node.func.id](self.visit(node.args[0]), self.builder) + # builder can be None for visit_expression + def visit_expression(self, node): + method = "_visit_expr_" + node.__class__.__name__ + try: + visitor = getattr(self, method) + except AttributeError: + raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__)) + return visitor(node) + + def _visit_expr_Name(self, node): + return self.ns.load(self.builder, node.id) + + def _visit_expr_NameConstant(self, node): + v = node.value + if isinstance(v, bool): + r = ir_values.VBool() else: raise NotImplementedError + if self.builder is not None: + r.create_constant(v) + return r + + def _visit_expr_Num(self, node): + n = node.n + if isinstance(n, int): + if abs(n) < 2**31: + r = ir_values.VInt() + else: + r = ir_values.VInt(64) + else: + raise NotImplementedError + if self.builder is not None: + r.create_constant(n) + return r + + def _visit_expr_UnaryOp(self, node): + ast_unops = { + ast.Invert: ir_values.operators.inv, + ast.Not: ir_values.operators.not_, + ast.UAdd: ir_values.operators.pos, + ast.USub: ir_values.operators.neg + } + return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder) + + def _visit_expr_BinOp(self, node): + ast_binops = { + ast.Add: ir_values.operators.add, + ast.Sub: ir_values.operators.sub, + ast.Mult: ir_values.operators.mul, + ast.Div: ir_values.operators.truediv, + ast.FloorDiv: ir_values.operators.floordiv, + ast.Mod: ir_values.operators.mod, + ast.Pow: ir_values.operators.pow, + ast.LShift: ir_values.operators.lshift, + ast.RShift: ir_values.operators.rshift, + ast.BitOr: ir_values.operators.or_, + ast.BitXor: ir_values.operators.xor, + ast.BitAnd: ir_values.operators.and_ + } + return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder) + + def _visit_expr_Compare(self, node): + ast_cmps = { + ast.Eq: ir_values.operators.eq, + ast.NotEq: ir_values.operators.ne, + ast.Lt: ir_values.operators.lt, + ast.LtE: ir_values.operators.le, + ast.Gt: ir_values.operators.gt, + ast.GtE: ir_values.operators.ge + } + comparisons = [] + old_comparator = self.visit_expression(node.left) + for op, comparator_a in zip(node.ops, node.comparators): + comparator = self.visit_expression(comparator_a) + comparison = ast_cmps[type(op)](old_comparator, comparator) + comparisons.append(comparison) + old_comparator = comparator + r = comparisons[0] + for comparison in comparisons[1:]: + r = ir_values.operators.and_(r, comparison) + return r + + def _visit_expr_Call(self, node): + ast_unfuns = { + "bool": ir_values.operators.bool, + "int": ir_values.operators.int, + "int64": ir_values.operators.int64, + "round": ir_values.operators.round, + "round64": ir_values.operators.round64, + } + return ast_unfuns[node.func.id](self.visit_expression(node.args[0]), self.builder) + + def visit_statements(self, stmts): + for node in stmts: + method = "_visit_stmt_" + node.__class__.__name__ + try: + visitor = getattr(self, method) + except AttributeError: + raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__)) + visitor(node) + + def _visit_stmt_Assign(self, node): + val = self.visit_expression(node.value) + for target in node.targets: + if isinstance(target, ast.Name): + self.ns.store(self.builder, val, target.id) + else: + raise NotImplementedError + + def _visit_stmt_AugAssign(self, node): + val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) + if isinstance(node.target, ast.Name): + self.ns.store(self.builder, val, node.target.id) + else: + raise NotImplementedError + + def _visit_stmt_Expr(self, node): + self.visit_expression(node.value) + + def _visit_stmt_If(self, node): + function = self.builder.basic_block.function + then_block = function.append_basic_block("i_then") + else_block = function.append_basic_block("i_else") + merge_block = function.append_basic_block("i_merge") + + condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) + self.builder.cbranch(condition.llvm_value, then_block, else_block) + + self.builder.position_at_end(then_block) + self.visit_statements(node.body) + self.builder.branch(merge_block) + + self.builder.position_at_end(else_block) + self.visit_statements(node.orelse) + self.builder.branch(merge_block) + + self.builder.position_at_end(merge_block) + + def _visit_stmt_While(self, node): + function = self.builder.basic_block.function + body_block = function.append_basic_block("w_body") + else_block = function.append_basic_block("w_else") + merge_block = function.append_basic_block("w_merge") + + condition = self.visit_expression(node.test) + self.builder.cbranch(condition, body_block, else_block) + + self.builder.position_at_end(body_block) + self.visit_statements(node.body) + condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) + self.builder.cbranch(condition.llvm_value, body_block, merge_block) + + self.builder.position_at_end(else_block) + self.visit_statements(node.orelse) + self.builder.branch(merge_block) + + self.builder.position_at_end(merge_block) diff --git a/artiq/compiler/ir_infer_types.py b/artiq/compiler/ir_infer_types.py index afea172a5..1d91f8c18 100644 --- a/artiq/compiler/ir_infer_types.py +++ b/artiq/compiler/ir_infer_types.py @@ -2,7 +2,7 @@ import ast from operator import itemgetter from copy import deepcopy -from artiq.compiler.ir_ast_body import ExpressionVisitor +from artiq.compiler.ir_ast_body import Visitor class _Namespace: def __init__(self, name_to_value): @@ -13,10 +13,10 @@ class _Namespace: class _TypeScanner(ast.NodeVisitor): def __init__(self, namespace): - self.exprv = ExpressionVisitor(None, namespace) + self.exprv = Visitor(None, namespace) def visit_Assign(self, node): - val = self.exprv.visit(node.value) + val = self.exprv.visit_expression(node.value) n2v = self.exprv.ns.name_to_value for target in node.targets: if isinstance(target, ast.Name): @@ -28,7 +28,7 @@ class _TypeScanner(ast.NodeVisitor): raise NotImplementedError def visit_AugAssign(self, node): - val = self.exprv.visit(ast.BinOp(op=node.op, left=node.target, right=node.value)) + val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) n2v = self.exprv.ns.name_to_value target = node.target if isinstance(target, ast.Name):