forked from M-Labs/artiq
compiler/ir_ast_body,ir_infer_types: use Python dict directly as namespace
This commit is contained in:
parent
30ef6119e6
commit
9e21ea5658
|
@ -1,11 +1,14 @@
|
||||||
import ast
|
import ast
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
from artiq.compiler import ir_values
|
from artiq.compiler import ir_values
|
||||||
|
|
||||||
class Visitor:
|
class Visitor:
|
||||||
def __init__(self, builder, ns):
|
def __init__(self, ns, builder=None):
|
||||||
self.builder = builder
|
|
||||||
self.ns = ns
|
self.ns = ns
|
||||||
|
self.builder = builder
|
||||||
|
|
||||||
# builder can be None for visit_expression
|
# builder can be None for visit_expression
|
||||||
def visit_expression(self, node):
|
def visit_expression(self, node):
|
||||||
|
@ -17,7 +20,17 @@ class Visitor:
|
||||||
return visitor(node)
|
return visitor(node)
|
||||||
|
|
||||||
def _visit_expr_Name(self, node):
|
def _visit_expr_Name(self, node):
|
||||||
return self.ns.load(self.builder, node.id)
|
try:
|
||||||
|
r = self.ns[node.id]
|
||||||
|
except KeyError:
|
||||||
|
raise NameError("Name '{}' is not defined".format(node.id))
|
||||||
|
r = copy(r)
|
||||||
|
if self.builder is None:
|
||||||
|
r.llvm_value = None
|
||||||
|
else:
|
||||||
|
if isinstance(r.llvm_value, lc.AllocaInstruction):
|
||||||
|
r.llvm_value = self.builder.load(r.llvm_value)
|
||||||
|
return r
|
||||||
|
|
||||||
def _visit_expr_NameConstant(self, node):
|
def _visit_expr_NameConstant(self, node):
|
||||||
v = node.value
|
v = node.value
|
||||||
|
@ -112,14 +125,14 @@ class Visitor:
|
||||||
val = self.visit_expression(node.value)
|
val = self.visit_expression(node.value)
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
self.ns.store(self.builder, val, target.id)
|
self.builder.store(val, self.ns[target.id])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _visit_stmt_AugAssign(self, node):
|
def _visit_stmt_AugAssign(self, node):
|
||||||
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||||
if isinstance(node.target, ast.Name):
|
if isinstance(node.target, ast.Name):
|
||||||
self.ns.store(self.builder, val, node.target.id)
|
self.builder.store(val, self.ns[node.target.id])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -4,51 +4,43 @@ from copy import deepcopy
|
||||||
|
|
||||||
from artiq.compiler.ir_ast_body import Visitor
|
from artiq.compiler.ir_ast_body import Visitor
|
||||||
|
|
||||||
class _Namespace:
|
|
||||||
def __init__(self, name_to_value):
|
|
||||||
self.name_to_value = name_to_value
|
|
||||||
|
|
||||||
def load(self, builder, name):
|
|
||||||
return self.name_to_value[name]
|
|
||||||
|
|
||||||
class _TypeScanner(ast.NodeVisitor):
|
class _TypeScanner(ast.NodeVisitor):
|
||||||
def __init__(self, namespace):
|
def __init__(self, ns):
|
||||||
self.exprv = Visitor(None, namespace)
|
self.exprv = Visitor(ns)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
val = self.exprv.visit_expression(node.value)
|
val = self.exprv.visit_expression(node.value)
|
||||||
n2v = self.exprv.ns.name_to_value
|
ns = self.exprv.ns
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in n2v:
|
if target.id in ns:
|
||||||
n2v[target.id].merge(val)
|
ns[target.id].merge(val)
|
||||||
else:
|
else:
|
||||||
n2v[target.id] = val
|
ns[target.id] = val
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
|
||||||
n2v = self.exprv.ns.name_to_value
|
ns = self.exprv.ns
|
||||||
target = node.target
|
target = node.target
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in n2v:
|
if target.id in ns:
|
||||||
n2v[target.id].merge(val)
|
ns[target.id].merge(val)
|
||||||
else:
|
else:
|
||||||
n2v[target.id] = val
|
ns[target.id] = val
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def infer_types(node):
|
def infer_types(node):
|
||||||
name_to_value = dict()
|
ns = dict()
|
||||||
while True:
|
while True:
|
||||||
prev_name_to_value = deepcopy(name_to_value)
|
prev_ns = deepcopy(ns)
|
||||||
ns = _Namespace(name_to_value)
|
|
||||||
ts = _TypeScanner(ns)
|
ts = _TypeScanner(ns)
|
||||||
ts.visit(node)
|
ts.visit(node)
|
||||||
if prev_name_to_value and all(v.same_type(prev_name_to_value[k]) for k, v in name_to_value.items()):
|
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||||
# no more promotions - completed
|
# no more promotions - completed
|
||||||
return name_to_value
|
return ns
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
testcode = """
|
testcode = """
|
||||||
|
@ -60,6 +52,6 @@ x = int64(7)
|
||||||
a += x # promotes a to int64
|
a += x # promotes a to int64
|
||||||
foo = True
|
foo = True
|
||||||
"""
|
"""
|
||||||
n2v = infer_types(ast.parse(testcode))
|
ns = infer_types(ast.parse(testcode))
|
||||||
for k, v in sorted(n2v.items(), key=itemgetter(0)):
|
for k, v in sorted(ns.items(), key=itemgetter(0)):
|
||||||
print("{:10}--> {}".format(k, str(v)))
|
print("{:10}--> {}".format(k, str(v)))
|
||||||
|
|
Loading…
Reference in New Issue