From 9e21ea5658e0723b8a76eea014aceb401ed9e177 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 17 Aug 2014 22:15:10 +0800 Subject: [PATCH] compiler/ir_ast_body,ir_infer_types: use Python dict directly as namespace --- artiq/compiler/ir_ast_body.py | 23 ++++++++++++++---- artiq/compiler/ir_infer_types.py | 40 +++++++++++++------------------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/artiq/compiler/ir_ast_body.py b/artiq/compiler/ir_ast_body.py index 7ac1c8846..ced70968a 100644 --- a/artiq/compiler/ir_ast_body.py +++ b/artiq/compiler/ir_ast_body.py @@ -1,11 +1,14 @@ import ast +from copy import copy + +from llvm import core as lc from artiq.compiler import ir_values class Visitor: - def __init__(self, builder, ns): - self.builder = builder + def __init__(self, ns, builder=None): self.ns = ns + self.builder = builder # builder can be None for visit_expression def visit_expression(self, node): @@ -17,7 +20,17 @@ class Visitor: return visitor(node) def _visit_expr_Name(self, node): - return self.ns.load(self.builder, node.id) + try: + r = self.ns[node.id] + except KeyError: + raise NameError("Name '{}' is not defined".format(node.id)) + r = copy(r) + if self.builder is None: + r.llvm_value = None + else: + if isinstance(r.llvm_value, lc.AllocaInstruction): + r.llvm_value = self.builder.load(r.llvm_value) + return r def _visit_expr_NameConstant(self, node): v = node.value @@ -112,14 +125,14 @@ class Visitor: val = self.visit_expression(node.value) for target in node.targets: if isinstance(target, ast.Name): - self.ns.store(self.builder, val, target.id) + self.builder.store(val, self.ns[target.id]) else: raise NotImplementedError def _visit_stmt_AugAssign(self, node): val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) if isinstance(node.target, ast.Name): - self.ns.store(self.builder, val, node.target.id) + self.builder.store(val, self.ns[node.target.id]) else: raise NotImplementedError diff --git a/artiq/compiler/ir_infer_types.py b/artiq/compiler/ir_infer_types.py index 1d91f8c18..62f2e3c44 100644 --- a/artiq/compiler/ir_infer_types.py +++ b/artiq/compiler/ir_infer_types.py @@ -4,51 +4,43 @@ from copy import deepcopy from artiq.compiler.ir_ast_body import Visitor -class _Namespace: - def __init__(self, name_to_value): - self.name_to_value = name_to_value - - def load(self, builder, name): - return self.name_to_value[name] - class _TypeScanner(ast.NodeVisitor): - def __init__(self, namespace): - self.exprv = Visitor(None, namespace) + def __init__(self, ns): + self.exprv = Visitor(ns) def visit_Assign(self, node): val = self.exprv.visit_expression(node.value) - n2v = self.exprv.ns.name_to_value + ns = self.exprv.ns for target in node.targets: if isinstance(target, ast.Name): - if target.id in n2v: - n2v[target.id].merge(val) + if target.id in ns: + ns[target.id].merge(val) else: - n2v[target.id] = val + ns[target.id] = val else: raise NotImplementedError def visit_AugAssign(self, node): val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) - n2v = self.exprv.ns.name_to_value + ns = self.exprv.ns target = node.target if isinstance(target, ast.Name): - if target.id in n2v: - n2v[target.id].merge(val) + if target.id in ns: + ns[target.id].merge(val) else: - n2v[target.id] = val + ns[target.id] = val else: raise NotImplementedError def infer_types(node): - name_to_value = dict() + ns = dict() while True: - prev_name_to_value = deepcopy(name_to_value) - ns = _Namespace(name_to_value) + prev_ns = deepcopy(ns) ts = _TypeScanner(ns) ts.visit(node) - if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()): + if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): # no more promotions - completed - return name_to_value + return ns if __name__ == "__main__": testcode = """ @@ -60,6 +52,6 @@ x = int64(7) a += x # promotes a to int64 foo = True """ - n2v = infer_types(ast.parse(testcode)) - for k, v in sorted(n2v.items(), key=itemgetter(0)): + ns = infer_types(ast.parse(testcode)) + for k, v in sorted(ns.items(), key=itemgetter(0)): print("{:10}--> {}".format(k, str(v)))