forked from M-Labs/artiq
1
0
Fork 0

compiler/ir_ast_body,ir_infer_types: use Python dict directly as namespace

This commit is contained in:
Sebastien Bourdeauducq 2014-08-17 22:15:10 +08:00
parent 30ef6119e6
commit 9e21ea5658
2 changed files with 34 additions and 29 deletions

View File

@ -1,11 +1,14 @@
import ast import ast
from copy import copy
from llvm import core as lc
from artiq.compiler import ir_values from artiq.compiler import ir_values
class Visitor: class Visitor:
def __init__(self, builder, ns): def __init__(self, ns, builder=None):
self.builder = builder
self.ns = ns self.ns = ns
self.builder = builder
# builder can be None for visit_expression # builder can be None for visit_expression
def visit_expression(self, node): def visit_expression(self, node):
@ -17,7 +20,17 @@ class Visitor:
return visitor(node) return visitor(node)
def _visit_expr_Name(self, node): def _visit_expr_Name(self, node):
return self.ns.load(self.builder, node.id) try:
r = self.ns[node.id]
except KeyError:
raise NameError("Name '{}' is not defined".format(node.id))
r = copy(r)
if self.builder is None:
r.llvm_value = None
else:
if isinstance(r.llvm_value, lc.AllocaInstruction):
r.llvm_value = self.builder.load(r.llvm_value)
return r
def _visit_expr_NameConstant(self, node): def _visit_expr_NameConstant(self, node):
v = node.value v = node.value
@ -112,14 +125,14 @@ class Visitor:
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.ns.store(self.builder, val, target.id) self.builder.store(val, self.ns[target.id])
else: else:
raise NotImplementedError raise NotImplementedError
def _visit_stmt_AugAssign(self, node): def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
if isinstance(node.target, ast.Name): if isinstance(node.target, ast.Name):
self.ns.store(self.builder, val, node.target.id) self.builder.store(val, self.ns[node.target.id])
else: else:
raise NotImplementedError raise NotImplementedError

View File

@ -4,51 +4,43 @@ from copy import deepcopy
from artiq.compiler.ir_ast_body import Visitor from artiq.compiler.ir_ast_body import Visitor
class _Namespace:
def __init__(self, name_to_value):
self.name_to_value = name_to_value
def load(self, builder, name):
return self.name_to_value[name]
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
def __init__(self, namespace): def __init__(self, ns):
self.exprv = Visitor(None, namespace) self.exprv = Visitor(ns)
def visit_Assign(self, node): def visit_Assign(self, node):
val = self.exprv.visit_expression(node.value) val = self.exprv.visit_expression(node.value)
n2v = self.exprv.ns.name_to_value ns = self.exprv.ns
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in n2v: if target.id in ns:
n2v[target.id].merge(val) ns[target.id].merge(val)
else: else:
n2v[target.id] = val ns[target.id] = val
else: else:
raise NotImplementedError raise NotImplementedError
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
n2v = self.exprv.ns.name_to_value ns = self.exprv.ns
target = node.target target = node.target
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in n2v: if target.id in ns:
n2v[target.id].merge(val) ns[target.id].merge(val)
else: else:
n2v[target.id] = val ns[target.id] = val
else: else:
raise NotImplementedError raise NotImplementedError
def infer_types(node): def infer_types(node):
name_to_value = dict() ns = dict()
while True: while True:
prev_name_to_value = deepcopy(name_to_value) prev_ns = deepcopy(ns)
ns = _Namespace(name_to_value)
ts = _TypeScanner(ns) ts = _TypeScanner(ns)
ts.visit(node) ts.visit(node)
if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()): if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed # no more promotions - completed
return name_to_value return ns
if __name__ == "__main__": if __name__ == "__main__":
testcode = """ testcode = """
@ -60,6 +52,6 @@ x = int64(7)
a += x # promotes a to int64 a += x # promotes a to int64
foo = True foo = True
""" """
n2v = infer_types(ast.parse(testcode)) ns = infer_types(ast.parse(testcode))
for k, v in sorted(n2v.items(), key=itemgetter(0)): for k, v in sorted(ns.items(), key=itemgetter(0)):
print("{:10}--> {}".format(k, str(v))) print("{:10}--> {}".format(k, str(v)))