forked from M-Labs/artiq
compiler/ir_ast_body,ir_infer_types: use Python dict directly as namespace
This commit is contained in:
parent
30ef6119e6
commit
9e21ea5658
|
@ -1,11 +1,14 @@
|
|||
import ast
|
||||
from copy import copy
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
from artiq.compiler import ir_values
|
||||
|
||||
class Visitor:
|
||||
def __init__(self, builder, ns):
|
||||
self.builder = builder
|
||||
def __init__(self, ns, builder=None):
|
||||
self.ns = ns
|
||||
self.builder = builder
|
||||
|
||||
# builder can be None for visit_expression
|
||||
def visit_expression(self, node):
|
||||
|
@ -17,7 +20,17 @@ class Visitor:
|
|||
return visitor(node)
|
||||
|
||||
def _visit_expr_Name(self, node):
|
||||
return self.ns.load(self.builder, node.id)
|
||||
try:
|
||||
r = self.ns[node.id]
|
||||
except KeyError:
|
||||
raise NameError("Name '{}' is not defined".format(node.id))
|
||||
r = copy(r)
|
||||
if self.builder is None:
|
||||
r.llvm_value = None
|
||||
else:
|
||||
if isinstance(r.llvm_value, lc.AllocaInstruction):
|
||||
r.llvm_value = self.builder.load(r.llvm_value)
|
||||
return r
|
||||
|
||||
def _visit_expr_NameConstant(self, node):
|
||||
v = node.value
|
||||
|
@ -112,14 +125,14 @@ class Visitor:
|
|||
val = self.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.ns.store(self.builder, val, target.id)
|
||||
self.builder.store(val, self.ns[target.id])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _visit_stmt_AugAssign(self, node):
|
||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.ns.store(self.builder, val, node.target.id)
|
||||
self.builder.store(val, self.ns[node.target.id])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -4,51 +4,43 @@ from copy import deepcopy
|
|||
|
||||
from artiq.compiler.ir_ast_body import Visitor
|
||||
|
||||
class _Namespace:
|
||||
def __init__(self, name_to_value):
|
||||
self.name_to_value = name_to_value
|
||||
|
||||
def load(self, builder, name):
|
||||
return self.name_to_value[name]
|
||||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
def __init__(self, namespace):
|
||||
self.exprv = Visitor(None, namespace)
|
||||
def __init__(self, ns):
|
||||
self.exprv = Visitor(ns)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
ns = self.exprv.ns
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in n2v:
|
||||
n2v[target.id].merge(val)
|
||||
if target.id in ns:
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
n2v[target.id] = val
|
||||
ns[target.id] = val
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||
n2v = self.exprv.ns.name_to_value
|
||||
ns = self.exprv.ns
|
||||
target = node.target
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in n2v:
|
||||
n2v[target.id].merge(val)
|
||||
if target.id in ns:
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
n2v[target.id] = val
|
||||
ns[target.id] = val
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_types(node):
|
||||
name_to_value = dict()
|
||||
ns = dict()
|
||||
while True:
|
||||
prev_name_to_value = deepcopy(name_to_value)
|
||||
ns = _Namespace(name_to_value)
|
||||
prev_ns = deepcopy(ns)
|
||||
ts = _TypeScanner(ns)
|
||||
ts.visit(node)
|
||||
if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()):
|
||||
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
# no more promotions - completed
|
||||
return name_to_value
|
||||
return ns
|
||||
|
||||
if __name__ == "__main__":
|
||||
testcode = """
|
||||
|
@ -60,6 +52,6 @@ x = int64(7)
|
|||
a += x # promotes a to int64
|
||||
foo = True
|
||||
"""
|
||||
n2v = infer_types(ast.parse(testcode))
|
||||
for k, v in sorted(n2v.items(), key=itemgetter(0)):
|
||||
ns = infer_types(ast.parse(testcode))
|
||||
for k, v in sorted(ns.items(), key=itemgetter(0)):
|
||||
print("{:10}--> {}".format(k, str(v)))
|
||||
|
|
Loading…
Reference in New Issue