forked from M-Labs/artiq
py2llvm: add support for function parameters and return values, add unit test
This commit is contained in:
parent
64c29bcfa6
commit
6deaf7b81a
|
@ -1,21 +1,9 @@
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
from llvm import passes as lp
|
from llvm import passes as lp
|
||||||
|
|
||||||
from artiq.py2llvm import infer_types, ast_body, values
|
from artiq.py2llvm import values
|
||||||
|
from artiq.py2llvm.functions import compile_function
|
||||||
|
from artiq.py2llvm.tools import add_common_passes
|
||||||
def _compile_function(module, env, funcdef):
|
|
||||||
function_type = lc.Type.function(lc.Type.void(), [])
|
|
||||||
function = module.add_function(function_type, funcdef.name)
|
|
||||||
bb = function.append_basic_block("entry")
|
|
||||||
builder = lc.Builder.new(bb)
|
|
||||||
|
|
||||||
ns = infer_types.infer_types(env, funcdef)
|
|
||||||
for k, v in ns.items():
|
|
||||||
v.alloca(builder, k)
|
|
||||||
visitor = ast_body.Visitor(env, ns, builder)
|
|
||||||
visitor.visit_statements(funcdef.body)
|
|
||||||
builder.ret_void()
|
|
||||||
|
|
||||||
|
|
||||||
def get_runtime_binary(env, funcdef):
|
def get_runtime_binary(env, funcdef):
|
||||||
|
@ -23,14 +11,10 @@ def get_runtime_binary(env, funcdef):
|
||||||
env.init_module(module)
|
env.init_module(module)
|
||||||
values.init_module(module)
|
values.init_module(module)
|
||||||
|
|
||||||
_compile_function(module, env, funcdef)
|
compile_function(module, env, funcdef, dict())
|
||||||
|
|
||||||
pass_manager = lp.PassManager.new()
|
pass_manager = lp.PassManager.new()
|
||||||
pass_manager.add(lp.PASS_MEM2REG)
|
add_common_passes(pass_manager)
|
||||||
pass_manager.add(lp.PASS_INSTCOMBINE)
|
|
||||||
pass_manager.add(lp.PASS_REASSOCIATE)
|
|
||||||
pass_manager.add(lp.PASS_GVN)
|
|
||||||
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
|
||||||
pass_manager.run(module)
|
pass_manager.run(module)
|
||||||
|
|
||||||
return env.emit_object()
|
return env.emit_object()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from artiq.py2llvm import values
|
from artiq.py2llvm import values
|
||||||
|
from artiq.py2llvm.tools import is_terminated
|
||||||
|
|
||||||
|
|
||||||
class Visitor:
|
class Visitor:
|
||||||
|
@ -131,13 +132,16 @@ class Visitor:
|
||||||
|
|
||||||
def visit_statements(self, stmts):
|
def visit_statements(self, stmts):
|
||||||
for node in stmts:
|
for node in stmts:
|
||||||
method = "_visit_stmt_" + node.__class__.__name__
|
node_type = node.__class__.__name__
|
||||||
|
method = "_visit_stmt_" + node_type
|
||||||
try:
|
try:
|
||||||
visitor = getattr(self, method)
|
visitor = getattr(self, method)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise NotImplementedError("Unsupported node '{}' in statement"
|
raise NotImplementedError("Unsupported node '{}' in statement"
|
||||||
.format(node.__class__.__name__))
|
.format(node_type))
|
||||||
visitor(node)
|
visitor(node)
|
||||||
|
if node_type == "Return":
|
||||||
|
break
|
||||||
|
|
||||||
def _visit_stmt_Assign(self, node):
|
def _visit_stmt_Assign(self, node):
|
||||||
val = self.visit_expression(node.value)
|
val = self.visit_expression(node.value)
|
||||||
|
@ -171,10 +175,12 @@ class Visitor:
|
||||||
|
|
||||||
self.builder.position_at_end(then_block)
|
self.builder.position_at_end(then_block)
|
||||||
self.visit_statements(node.body)
|
self.visit_statements(node.body)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
self.builder.branch(merge_block)
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
self.builder.position_at_end(else_block)
|
self.builder.position_at_end(else_block)
|
||||||
self.visit_statements(node.orelse)
|
self.visit_statements(node.orelse)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
self.builder.branch(merge_block)
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
self.builder.position_at_end(merge_block)
|
self.builder.position_at_end(merge_block)
|
||||||
|
@ -192,6 +198,7 @@ class Visitor:
|
||||||
|
|
||||||
self.builder.position_at_end(body_block)
|
self.builder.position_at_end(body_block)
|
||||||
self.visit_statements(node.body)
|
self.visit_statements(node.body)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
condition = values.operators.bool(
|
condition = values.operators.bool(
|
||||||
self.visit_expression(node.test), self.builder)
|
self.visit_expression(node.test), self.builder)
|
||||||
self.builder.cbranch(
|
self.builder.cbranch(
|
||||||
|
@ -199,6 +206,17 @@ class Visitor:
|
||||||
|
|
||||||
self.builder.position_at_end(else_block)
|
self.builder.position_at_end(else_block)
|
||||||
self.visit_statements(node.orelse)
|
self.visit_statements(node.orelse)
|
||||||
|
if not is_terminated(self.builder.basic_block):
|
||||||
self.builder.branch(merge_block)
|
self.builder.branch(merge_block)
|
||||||
|
|
||||||
self.builder.position_at_end(merge_block)
|
self.builder.position_at_end(merge_block)
|
||||||
|
|
||||||
|
def _visit_stmt_Return(self, node):
|
||||||
|
if node.value is None:
|
||||||
|
val = values.VNone()
|
||||||
|
else:
|
||||||
|
val = self.visit_expression(node.value)
|
||||||
|
if isinstance(val, values.VNone):
|
||||||
|
self.builder.ret_void()
|
||||||
|
else:
|
||||||
|
self.builder.ret(val.get_ssa_value(self.builder))
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
from llvm import core as lc
|
||||||
|
|
||||||
|
from artiq.py2llvm import infer_types, ast_body, values, tools
|
||||||
|
|
||||||
|
def compile_function(module, env, funcdef, param_types):
|
||||||
|
ns = infer_types.infer_function_types(env, funcdef, param_types)
|
||||||
|
retval = ns["return"]
|
||||||
|
|
||||||
|
function_type = lc.Type.function(retval.get_llvm_type(),
|
||||||
|
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
|
||||||
|
function = module.add_function(function_type, funcdef.name)
|
||||||
|
bb = function.append_basic_block("entry")
|
||||||
|
builder = lc.Builder.new(bb)
|
||||||
|
|
||||||
|
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||||
|
arg_llvm.name = arg_ast.arg
|
||||||
|
for k, v in ns.items():
|
||||||
|
v.alloca(builder, k)
|
||||||
|
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
|
||||||
|
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
|
||||||
|
|
||||||
|
visitor = ast_body.Visitor(env, ns, builder)
|
||||||
|
visitor.visit_statements(funcdef.body)
|
||||||
|
|
||||||
|
if not tools.is_terminated(builder.basic_block):
|
||||||
|
if isinstance(retval, values.VNone):
|
||||||
|
builder.ret_void()
|
||||||
|
else:
|
||||||
|
builder.ret(retval.get_ssa_value(builder))
|
||||||
|
|
||||||
|
return function, retval
|
|
@ -1,61 +1,55 @@
|
||||||
import ast
|
import ast
|
||||||
from operator import itemgetter
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from artiq.py2llvm.ast_body import Visitor
|
from artiq.py2llvm.ast_body import Visitor
|
||||||
|
from artiq.py2llvm import values
|
||||||
|
|
||||||
|
|
||||||
class _TypeScanner(ast.NodeVisitor):
|
class _TypeScanner(ast.NodeVisitor):
|
||||||
def __init__(self, env, ns):
|
def __init__(self, env, ns):
|
||||||
self.exprv = Visitor(env, ns)
|
self.exprv = Visitor(env, ns)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def _update_target(self, target, val):
|
||||||
val = self.exprv.visit_expression(node.value)
|
|
||||||
ns = self.exprv.ns
|
ns = self.exprv.ns
|
||||||
for target in node.targets:
|
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in ns:
|
if target.id in ns:
|
||||||
ns[target.id].merge(val)
|
ns[target.id].merge(val)
|
||||||
else:
|
else:
|
||||||
ns[target.id] = val
|
ns[target.id] = deepcopy(val)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
val = self.exprv.visit_expression(node.value)
|
||||||
|
for target in node.targets:
|
||||||
|
self._update_target(target, val)
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
val = self.exprv.visit_expression(ast.BinOp(
|
val = self.exprv.visit_expression(ast.BinOp(
|
||||||
op=node.op, left=node.target, right=node.value))
|
op=node.op, left=node.target, right=node.value))
|
||||||
|
self._update_target(node.target, val)
|
||||||
|
|
||||||
|
def visit_Return(self, node):
|
||||||
|
if node.value is None:
|
||||||
|
val = values.VNone()
|
||||||
|
else:
|
||||||
|
val = self.exprv.visit_expression(node.value)
|
||||||
ns = self.exprv.ns
|
ns = self.exprv.ns
|
||||||
target = node.target
|
if "return" in ns:
|
||||||
if isinstance(target, ast.Name):
|
ns["return"].merge(val)
|
||||||
if target.id in ns:
|
|
||||||
ns[target.id].merge(val)
|
|
||||||
else:
|
else:
|
||||||
ns[target.id] = val
|
ns["return"] = deepcopy(val)
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
def infer_function_types(env, node, param_types):
|
||||||
def infer_types(env, node):
|
ns = deepcopy(param_types)
|
||||||
ns = dict()
|
ts = _TypeScanner(env, ns)
|
||||||
|
ts.visit(node)
|
||||||
while True:
|
while True:
|
||||||
prev_ns = deepcopy(ns)
|
prev_ns = deepcopy(ns)
|
||||||
ts = _TypeScanner(env, ns)
|
ts = _TypeScanner(env, ns)
|
||||||
ts.visit(node)
|
ts.visit(node)
|
||||||
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
|
||||||
# no more promotions - completed
|
# no more promotions - completed
|
||||||
|
if "return" not in ns:
|
||||||
|
ns["return"] = values.VNone()
|
||||||
return ns
|
return ns
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
testcode = """
|
|
||||||
a = 2 # promoted later to int64
|
|
||||||
b = a + 1 # initially int32, becomes int64 after a is promoted
|
|
||||||
c = b//2 # initially int32, becomes int64 after b is promoted
|
|
||||||
d = 4 # stays int32
|
|
||||||
x = int64(7)
|
|
||||||
a += x # promotes a to int64
|
|
||||||
foo = True
|
|
||||||
bar = None
|
|
||||||
"""
|
|
||||||
ns = infer_types(None, ast.parse(testcode))
|
|
||||||
for k, v in sorted(ns.items(), key=itemgetter(0)):
|
|
||||||
print("{:10}--> {}".format(k, str(v)))
|
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from llvm import passes as lp
|
||||||
|
|
||||||
|
def is_terminated(basic_block):
|
||||||
|
return basic_block.instructions and basic_block.instructions[-1].is_terminator
|
||||||
|
|
||||||
|
def add_common_passes(pass_manager):
|
||||||
|
pass_manager.add(lp.PASS_MEM2REG)
|
||||||
|
pass_manager.add(lp.PASS_INSTCOMBINE)
|
||||||
|
pass_manager.add(lp.PASS_REASSOCIATE)
|
||||||
|
pass_manager.add(lp.PASS_GVN)
|
||||||
|
pass_manager.add(lp.PASS_SIMPLIFYCFG)
|
|
@ -24,7 +24,7 @@ class _Value:
|
||||||
|
|
||||||
def alloca(self, builder, name):
|
def alloca(self, builder, name):
|
||||||
if self._llvm_value is not None:
|
if self._llvm_value is not None:
|
||||||
raise RuntimeError("Attempted to alloca existing LLVM value")
|
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
|
||||||
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
|
||||||
|
|
||||||
def o_int(self, builder):
|
def o_int(self, builder):
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
import unittest
|
||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from llvm import core as lc
|
||||||
|
from llvm import passes as lp
|
||||||
|
from llvm import ee as le
|
||||||
|
|
||||||
|
from artiq.py2llvm.infer_types import infer_function_types
|
||||||
|
from artiq.py2llvm import values
|
||||||
|
from artiq.py2llvm import compile_function
|
||||||
|
from artiq.py2llvm.tools import add_common_passes
|
||||||
|
|
||||||
|
|
||||||
|
def test_types(choice):
|
||||||
|
a = 2 # promoted later to int64
|
||||||
|
b = a + 1 # initially int32, becomes int64 after a is promoted
|
||||||
|
c = b//2 # initially int32, becomes int64 after b is promoted
|
||||||
|
d = 4 # stays int32
|
||||||
|
x = int64(7)
|
||||||
|
a += x # promotes a to int64
|
||||||
|
foo = True
|
||||||
|
bar = None
|
||||||
|
|
||||||
|
if choice:
|
||||||
|
return 3
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
class FunctionTypesCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.ns = infer_function_types(
|
||||||
|
None, ast.parse(inspect.getsource(test_types)),
|
||||||
|
dict())
|
||||||
|
|
||||||
|
def test_base_types(self):
|
||||||
|
self.assertIsInstance(self.ns["foo"], values.VBool)
|
||||||
|
self.assertIsInstance(self.ns["bar"], values.VNone)
|
||||||
|
self.assertIsInstance(self.ns["d"], values.VInt)
|
||||||
|
self.assertEqual(self.ns["d"].nbits, 32)
|
||||||
|
self.assertIsInstance(self.ns["x"], values.VInt)
|
||||||
|
self.assertEqual(self.ns["x"].nbits, 64)
|
||||||
|
|
||||||
|
def test_promotion(self):
|
||||||
|
for v in "abc":
|
||||||
|
self.assertIsInstance(self.ns[v], values.VInt)
|
||||||
|
self.assertEqual(self.ns[v].nbits, 64)
|
||||||
|
|
||||||
|
def test_return(self):
|
||||||
|
self.assertIsInstance(self.ns["return"], values.VInt)
|
||||||
|
self.assertEqual(self.ns["return"].nbits, 64)
|
||||||
|
|
||||||
|
|
||||||
|
class CompiledFunction:
|
||||||
|
def __init__(self, function, param_types):
|
||||||
|
module = lc.Module.new("main")
|
||||||
|
values.init_module(module)
|
||||||
|
|
||||||
|
funcdef = ast.parse(inspect.getsource(function)).body[0]
|
||||||
|
self.function, self.retval = compile_function(
|
||||||
|
module, None, funcdef, param_types)
|
||||||
|
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
|
||||||
|
|
||||||
|
self.executor = le.ExecutionEngine.new(module)
|
||||||
|
pass_manager = lp.PassManager.new()
|
||||||
|
add_common_passes(pass_manager)
|
||||||
|
pass_manager.run(module)
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
args_llvm = [
|
||||||
|
le.GenericValue.int(av.get_llvm_type(), a)
|
||||||
|
for av, a in zip(self.argval, args)]
|
||||||
|
result = self.executor.run_function(self.function, args_llvm)
|
||||||
|
if isinstance(self.retval, values.VBool):
|
||||||
|
return bool(result.as_int())
|
||||||
|
elif isinstance(self.retval, values.VInt):
|
||||||
|
return result.as_int_signed()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def is_prime(x):
|
||||||
|
d = 2
|
||||||
|
while d*d <= x:
|
||||||
|
if not x % d:
|
||||||
|
return False
|
||||||
|
d += 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
class CodeGenCase(unittest.TestCase):
|
||||||
|
def test_is_prime(self):
|
||||||
|
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
|
||||||
|
for i in range(200):
|
||||||
|
self.assertEqual(is_prime_c(i), is_prime(i))
|
Loading…
Reference in New Issue