mirror of https://github.com/m-labs/artiq.git
py2llvm: add support for function parameters and return values, add unit test
This commit is contained in:
parent
64c29bcfa6
commit
6deaf7b81a
|
@ -1,21 +1,9 @@
|
|||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
|
||||
from artiq.py2llvm import infer_types, ast_body, values
|
||||
|
||||
|
||||
def _compile_function(module, env, funcdef):
|
||||
function_type = lc.Type.function(lc.Type.void(), [])
|
||||
function = module.add_function(function_type, funcdef.name)
|
||||
bb = function.append_basic_block("entry")
|
||||
builder = lc.Builder.new(bb)
|
||||
|
||||
ns = infer_types.infer_types(env, funcdef)
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
visitor = ast_body.Visitor(env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
builder.ret_void()
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm.functions import compile_function
|
||||
from artiq.py2llvm.tools import add_common_passes
|
||||
|
||||
|
||||
def get_runtime_binary(env, funcdef):
|
||||
|
@ -23,14 +11,10 @@ def get_runtime_binary(env, funcdef):
|
|||
env.init_module(module)
|
||||
values.init_module(module)
|
||||
|
||||
_compile_function(module, env, funcdef)
|
||||
compile_function(module, env, funcdef, dict())
|
||||
|
||||
pass_manager = lp.PassManager.new()
|
||||
pass_manager.add(lp.PASS_MEM2REG)
|
||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||
pass_manager.add(lp.PASS_GVN)
|
||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
||||
add_common_passes(pass_manager)
|
||||
pass_manager.run(module)
|
||||
|
||||
return env.emit_object()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import ast
|
||||
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm.tools import is_terminated
|
||||
|
||||
|
||||
class Visitor:
|
||||
|
@ -131,13 +132,16 @@ class Visitor:
|
|||
|
||||
def visit_statements(self, stmts):
|
||||
for node in stmts:
|
||||
method = "_visit_stmt_" + node.__class__.__name__
|
||||
node_type = node.__class__.__name__
|
||||
method = "_visit_stmt_" + node_type
|
||||
try:
|
||||
visitor = getattr(self, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError("Unsupported node '{}' in statement"
|
||||
.format(node.__class__.__name__))
|
||||
.format(node_type))
|
||||
visitor(node)
|
||||
if node_type == "Return":
|
||||
break
|
||||
|
||||
def _visit_stmt_Assign(self, node):
|
||||
val = self.visit_expression(node.value)
|
||||
|
@ -165,17 +169,19 @@ class Visitor:
|
|||
merge_block = function.append_basic_block("i_merge")
|
||||
|
||||
condition = values.operators.bool(self.visit_expression(node.test),
|
||||
self.builder)
|
||||
self.builder)
|
||||
self.builder.cbranch(condition.get_ssa_value(self.builder),
|
||||
then_block, else_block)
|
||||
|
||||
self.builder.position_at_end(then_block)
|
||||
self.visit_statements(node.body)
|
||||
self.builder.branch(merge_block)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
self.builder.branch(merge_block)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
|
@ -192,13 +198,25 @@ class Visitor:
|
|||
|
||||
self.builder.position_at_end(body_block)
|
||||
self.visit_statements(node.body)
|
||||
condition = values.operators.bool(
|
||||
self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.get_ssa_value(self.builder), body_block, merge_block)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
condition = values.operators.bool(
|
||||
self.visit_expression(node.test), self.builder)
|
||||
self.builder.cbranch(
|
||||
condition.get_ssa_value(self.builder), body_block, merge_block)
|
||||
|
||||
self.builder.position_at_end(else_block)
|
||||
self.visit_statements(node.orelse)
|
||||
self.builder.branch(merge_block)
|
||||
if not is_terminated(self.builder.basic_block):
|
||||
self.builder.branch(merge_block)
|
||||
|
||||
self.builder.position_at_end(merge_block)
|
||||
|
||||
def _visit_stmt_Return(self, node):
|
||||
if node.value is None:
|
||||
val = values.VNone()
|
||||
else:
|
||||
val = self.visit_expression(node.value)
|
||||
if isinstance(val, values.VNone):
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.builder.ret(val.get_ssa_value(self.builder))
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm import infer_types, ast_body, values, tools
|
||||
|
||||
def compile_function(module, env, funcdef, param_types):
|
||||
ns = infer_types.infer_function_types(env, funcdef, param_types)
|
||||
retval = ns["return"]
|
||||
|
||||
function_type = lc.Type.function(retval.get_llvm_type(),
|
||||
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
||||
function = module.add_function(function_type, funcdef.name)
|
||||
bb = function.append_basic_block("entry")
|
||||
builder = lc.Builder.new(bb)
|
||||
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
arg_llvm.name = arg_ast.arg
|
||||
for k, v in ns.items():
|
||||
v.alloca(builder, k)
|
||||
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||
|
||||
visitor = ast_body.Visitor(env, ns, builder)
|
||||
visitor.visit_statements(funcdef.body)
|
||||
|
||||
if not tools.is_terminated(builder.basic_block):
|
||||
if isinstance(retval, values.VNone):
|
||||
builder.ret_void()
|
||||
else:
|
||||
builder.ret(retval.get_ssa_value(builder))
|
||||
|
||||
return function, retval
|
|
@ -1,61 +1,55 @@
|
|||
import ast
|
||||
from operator import itemgetter
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.py2llvm.ast_body import Visitor
|
||||
from artiq.py2llvm import values
|
||||
|
||||
|
||||
class _TypeScanner(ast.NodeVisitor):
|
||||
def __init__(self, env, ns):
|
||||
self.exprv = Visitor(env, ns)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
def _update_target(self, target, val):
|
||||
ns = self.exprv.ns
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in ns:
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
ns[target.id] = val
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit_expression(ast.BinOp(
|
||||
op=node.op, left=node.target, right=node.value))
|
||||
ns = self.exprv.ns
|
||||
target = node.target
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in ns:
|
||||
ns[target.id].merge(val)
|
||||
else:
|
||||
ns[target.id] = val
|
||||
ns[target.id] = deepcopy(val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def visit_Assign(self, node):
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
for target in node.targets:
|
||||
self._update_target(target, val)
|
||||
|
||||
def infer_types(env, node):
|
||||
ns = dict()
|
||||
def visit_AugAssign(self, node):
|
||||
val = self.exprv.visit_expression(ast.BinOp(
|
||||
op=node.op, left=node.target, right=node.value))
|
||||
self._update_target(node.target, val)
|
||||
|
||||
def visit_Return(self, node):
|
||||
if node.value is None:
|
||||
val = values.VNone()
|
||||
else:
|
||||
val = self.exprv.visit_expression(node.value)
|
||||
ns = self.exprv.ns
|
||||
if "return" in ns:
|
||||
ns["return"].merge(val)
|
||||
else:
|
||||
ns["return"] = deepcopy(val)
|
||||
|
||||
def infer_function_types(env, node, param_types):
|
||||
ns = deepcopy(param_types)
|
||||
ts = _TypeScanner(env, ns)
|
||||
ts.visit(node)
|
||||
while True:
|
||||
prev_ns = deepcopy(ns)
|
||||
ts = _TypeScanner(env, ns)
|
||||
ts.visit(node)
|
||||
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||
# no more promotions - completed
|
||||
if "return" not in ns:
|
||||
ns["return"] = values.VNone()
|
||||
return ns
|
||||
|
||||
if __name__ == "__main__":
|
||||
testcode = """
|
||||
a = 2 # promoted later to int64
|
||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||
d = 4 # stays int32
|
||||
x = int64(7)
|
||||
a += x # promotes a to int64
|
||||
foo = True
|
||||
bar = None
|
||||
"""
|
||||
ns = infer_types(None, ast.parse(testcode))
|
||||
for k, v in sorted(ns.items(), key=itemgetter(0)):
|
||||
print("{:10}--> {}".format(k, str(v)))
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
from llvm import passes as lp
|
||||
|
||||
def is_terminated(basic_block):
|
||||
return basic_block.instructions and basic_block.instructions[-1].is_terminator
|
||||
|
||||
def add_common_passes(pass_manager):
|
||||
pass_manager.add(lp.PASS_MEM2REG)
|
||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||
pass_manager.add(lp.PASS_GVN)
|
||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
|
@ -24,7 +24,7 @@ class _Value:
|
|||
|
||||
def alloca(self, builder, name):
|
||||
if self._llvm_value is not None:
|
||||
raise RuntimeError("Attempted to alloca existing LLVM value")
|
||||
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
||||
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||
|
||||
def o_int(self, builder):
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
import unittest
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
from llvm import ee as le
|
||||
|
||||
from artiq.py2llvm.infer_types import infer_function_types
|
||||
from artiq.py2llvm import values
|
||||
from artiq.py2llvm import compile_function
|
||||
from artiq.py2llvm.tools import add_common_passes
|
||||
|
||||
|
||||
def test_types(choice):
|
||||
a = 2 # promoted later to int64
|
||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||
d = 4 # stays int32
|
||||
x = int64(7)
|
||||
a += x # promotes a to int64
|
||||
foo = True
|
||||
bar = None
|
||||
|
||||
if choice:
|
||||
return 3
|
||||
else:
|
||||
return x
|
||||
|
||||
class FunctionTypesCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ns = infer_function_types(
|
||||
None, ast.parse(inspect.getsource(test_types)),
|
||||
dict())
|
||||
|
||||
def test_base_types(self):
|
||||
self.assertIsInstance(self.ns["foo"], values.VBool)
|
||||
self.assertIsInstance(self.ns["bar"], values.VNone)
|
||||
self.assertIsInstance(self.ns["d"], values.VInt)
|
||||
self.assertEqual(self.ns["d"].nbits, 32)
|
||||
self.assertIsInstance(self.ns["x"], values.VInt)
|
||||
self.assertEqual(self.ns["x"].nbits, 64)
|
||||
|
||||
def test_promotion(self):
|
||||
for v in "abc":
|
||||
self.assertIsInstance(self.ns[v], values.VInt)
|
||||
self.assertEqual(self.ns[v].nbits, 64)
|
||||
|
||||
def test_return(self):
|
||||
self.assertIsInstance(self.ns["return"], values.VInt)
|
||||
self.assertEqual(self.ns["return"].nbits, 64)
|
||||
|
||||
|
||||
class CompiledFunction:
|
||||
def __init__(self, function, param_types):
|
||||
module = lc.Module.new("main")
|
||||
values.init_module(module)
|
||||
|
||||
funcdef = ast.parse(inspect.getsource(function)).body[0]
|
||||
self.function, self.retval = compile_function(
|
||||
module, None, funcdef, param_types)
|
||||
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
|
||||
|
||||
self.executor = le.ExecutionEngine.new(module)
|
||||
pass_manager = lp.PassManager.new()
|
||||
add_common_passes(pass_manager)
|
||||
pass_manager.run(module)
|
||||
|
||||
def __call__(self, *args):
|
||||
args_llvm = [
|
||||
le.GenericValue.int(av.get_llvm_type(), a)
|
||||
for av, a in zip(self.argval, args)]
|
||||
result = self.executor.run_function(self.function, args_llvm)
|
||||
if isinstance(self.retval, values.VBool):
|
||||
return bool(result.as_int())
|
||||
elif isinstance(self.retval, values.VInt):
|
||||
return result.as_int_signed()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def is_prime(x):
|
||||
d = 2
|
||||
while d*d <= x:
|
||||
if not x % d:
|
||||
return False
|
||||
d += 1
|
||||
return True
|
||||
|
||||
class CodeGenCase(unittest.TestCase):
|
||||
def test_is_prime(self):
|
||||
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
|
||||
for i in range(200):
|
||||
self.assertEqual(is_prime_c(i), is_prime(i))
|
Loading…
Reference in New Issue