forked from M-Labs/artiq
compiler/ir_ast_body,ir_infer_types: support syscalls
This commit is contained in:
parent
4b0788d92c
commit
65566ec710
@ -6,7 +6,8 @@ from llvm import core as lc
|
||||
from artiq.compiler import ir_values
|
||||
|
||||
class Visitor:
|
||||
def __init__(self, ns, builder=None):
|
||||
def __init__(self, env, ns, builder=None):
|
||||
self.env = env
|
||||
self.ns = ns
|
||||
self.builder = builder
|
||||
|
||||
@ -112,7 +113,15 @@ class Visitor:
|
||||
"round": ir_values.operators.round,
|
||||
"round64": ir_values.operators.round64,
|
||||
}
|
||||
return ast_unfuns[node.func.id](self.visit_expression(node.args[0]), self.builder)
|
||||
fn = node.func.id
|
||||
if fn in ast_unfuns:
|
||||
return ast_unfuns[fn](self.visit_expression(node.args[0]), self.builder)
|
||||
elif fn == "syscall":
|
||||
return self.env.syscall(node.args[0].s,
|
||||
[self.visit_expression(expr) for expr in node.args[1:]],
|
||||
self.builder)
|
||||
else:
|
||||
raise NameError("Function '{}' is not defined".format(fn))
|
||||
|
||||
def visit_statements(self, stmts):
|
||||
for node in stmts:
|
||||
|
@ -5,8 +5,8 @@ from copy import deepcopy
|
||||
from artiq.compiler.ir_ast_body import Visitor
|
||||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
def __init__(self, ns):
|
||||
self.exprv = Visitor(ns)
|
||||
def __init__(self, env, ns):
|
||||
self.exprv = Visitor(env, ns)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
@ -32,11 +32,11 @@ class _TypeScanner(ast.NodeVisitor):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_types(node):
|
||||
def infer_types(env, node):
|
||||
ns = dict()
|
||||
while True:
|
||||
prev_ns = deepcopy(ns)
|
||||
ts = _TypeScanner(ns)
|
||||
ts = _TypeScanner(env, ns)
|
||||
ts.visit(node)
|
||||
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
# no more promotions - completed
|
||||
@ -53,6 +53,6 @@ a += x # promotes a to int64
|
||||
foo = True
|
||||
bar = None
|
||||
"""
|
||||
ns = infer_types(ast.parse(testcode))
|
||||
ns = infer_types(None, ast.parse(testcode))
|
||||
for k, v in sorted(ns.items(), key=itemgetter(0)):
|
||||
print("{:10}--> {}".format(k, str(v)))
|
||||
|
Loading…
Reference in New Issue
Block a user