forked from M-Labs/artiq
1
0
Fork 0

py2llvm: add support for function parameters and return values, add unit test

This commit is contained in:
Sebastien Bourdeauducq 2014-09-06 19:03:08 +08:00
parent 64c29bcfa6
commit 6deaf7b81a
7 changed files with 200 additions and 68 deletions

View File

@ -1,21 +1,9 @@
from llvm import core as lc from llvm import core as lc
from llvm import passes as lp from llvm import passes as lp
from artiq.py2llvm import infer_types, ast_body, values from artiq.py2llvm import values
from artiq.py2llvm.functions import compile_function
from artiq.py2llvm.tools import add_common_passes
def _compile_function(module, env, funcdef):
function_type = lc.Type.function(lc.Type.void(), [])
function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
ns = infer_types.infer_types(env, funcdef)
for k, v in ns.items():
v.alloca(builder, k)
visitor = ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body)
builder.ret_void()
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
@ -23,14 +11,10 @@ def get_runtime_binary(env, funcdef):
env.init_module(module) env.init_module(module)
values.init_module(module) values.init_module(module)
_compile_function(module, env, funcdef) compile_function(module, env, funcdef, dict())
pass_manager = lp.PassManager.new() pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG) add_common_passes(pass_manager)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)
pass_manager.run(module) pass_manager.run(module)
return env.emit_object() return env.emit_object()

View File

@ -1,6 +1,7 @@
import ast import ast
from artiq.py2llvm import values from artiq.py2llvm import values
from artiq.py2llvm.tools import is_terminated
class Visitor: class Visitor:
@ -131,13 +132,16 @@ class Visitor:
def visit_statements(self, stmts): def visit_statements(self, stmts):
for node in stmts: for node in stmts:
method = "_visit_stmt_" + node.__class__.__name__ node_type = node.__class__.__name__
method = "_visit_stmt_" + node_type
try: try:
visitor = getattr(self, method) visitor = getattr(self, method)
except AttributeError: except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement" raise NotImplementedError("Unsupported node '{}' in statement"
.format(node.__class__.__name__)) .format(node_type))
visitor(node) visitor(node)
if node_type == "Return":
break
def _visit_stmt_Assign(self, node): def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
@ -171,10 +175,12 @@ class Visitor:
self.builder.position_at_end(then_block) self.builder.position_at_end(then_block)
self.visit_statements(node.body) self.visit_statements(node.body)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(merge_block) self.builder.position_at_end(merge_block)
@ -192,6 +198,7 @@ class Visitor:
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
if not is_terminated(self.builder.basic_block):
condition = values.operators.bool( condition = values.operators.bool(
self.visit_expression(node.test), self.builder) self.visit_expression(node.test), self.builder)
self.builder.cbranch( self.builder.cbranch(
@ -199,6 +206,17 @@ class Visitor:
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(merge_block) self.builder.position_at_end(merge_block)
def _visit_stmt_Return(self, node):
if node.value is None:
val = values.VNone()
else:
val = self.visit_expression(node.value)
if isinstance(val, values.VNone):
self.builder.ret_void()
else:
self.builder.ret(val.get_ssa_value(self.builder))

View File

@ -0,0 +1,31 @@
from llvm import core as lc
from artiq.py2llvm import infer_types, ast_body, values, tools
def compile_function(module, env, funcdef, param_types):
ns = infer_types.infer_function_types(env, funcdef, param_types)
retval = ns["return"]
function_type = lc.Type.function(retval.get_llvm_type(),
[ns[arg.arg].get_llvm_type() for arg in funcdef.args.args])
function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
arg_llvm.name = arg_ast.arg
for k, v in ns.items():
v.alloca(builder, k)
for arg_ast, arg_llvm in zip(funcdef.args.args, function.args):
ns[arg_ast.arg].set_ssa_value(builder, arg_llvm)
visitor = ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body)
if not tools.is_terminated(builder.basic_block):
if isinstance(retval, values.VNone):
builder.ret_void()
else:
builder.ret(retval.get_ssa_value(builder))
return function, retval

View File

@ -1,61 +1,55 @@
import ast import ast
from operator import itemgetter
from copy import deepcopy from copy import deepcopy
from artiq.py2llvm.ast_body import Visitor from artiq.py2llvm.ast_body import Visitor
from artiq.py2llvm import values
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
def __init__(self, env, ns): def __init__(self, env, ns):
self.exprv = Visitor(env, ns) self.exprv = Visitor(env, ns)
def visit_Assign(self, node): def _update_target(self, target, val):
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns ns = self.exprv.ns
for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in ns: if target.id in ns:
ns[target.id].merge(val) ns[target.id].merge(val)
else: else:
ns[target.id] = val ns[target.id] = deepcopy(val)
else: else:
raise NotImplementedError raise NotImplementedError
def visit_Assign(self, node):
val = self.exprv.visit_expression(node.value)
for target in node.targets:
self._update_target(target, val)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp( val = self.exprv.visit_expression(ast.BinOp(
op=node.op, left=node.target, right=node.value)) op=node.op, left=node.target, right=node.value))
self._update_target(node.target, val)
def visit_Return(self, node):
if node.value is None:
val = values.VNone()
else:
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns ns = self.exprv.ns
target = node.target if "return" in ns:
if isinstance(target, ast.Name): ns["return"].merge(val)
if target.id in ns:
ns[target.id].merge(val)
else: else:
ns[target.id] = val ns["return"] = deepcopy(val)
else:
raise NotImplementedError
def infer_function_types(env, node, param_types):
def infer_types(env, node): ns = deepcopy(param_types)
ns = dict() ts = _TypeScanner(env, ns)
ts.visit(node)
while True: while True:
prev_ns = deepcopy(ns) prev_ns = deepcopy(ns)
ts = _TypeScanner(env, ns) ts = _TypeScanner(env, ns)
ts.visit(node) ts.visit(node)
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed # no more promotions - completed
if "return" not in ns:
ns["return"] = values.VNone()
return ns return ns
if __name__ == "__main__":
testcode = """
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True
bar = None
"""
ns = infer_types(None, ast.parse(testcode))
for k, v in sorted(ns.items(), key=itemgetter(0)):
print("{:10}--> {}".format(k, str(v)))

11
artiq/py2llvm/tools.py Normal file
View File

@ -0,0 +1,11 @@
from llvm import passes as lp
def is_terminated(basic_block):
return basic_block.instructions and basic_block.instructions[-1].is_terminator
def add_common_passes(pass_manager):
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)

View File

@ -24,7 +24,7 @@ class _Value:
def alloca(self, builder, name): def alloca(self, builder, name):
if self._llvm_value is not None: if self._llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value") raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
def o_int(self, builder): def o_int(self, builder):

94
test/py2llvm.py Normal file
View File

@ -0,0 +1,94 @@
import unittest
import ast
import inspect
from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import values
from artiq.py2llvm import compile_function
from artiq.py2llvm.tools import add_common_passes
def test_types(choice):
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True
bar = None
if choice:
return 3
else:
return x
class FunctionTypesCase(unittest.TestCase):
def setUp(self):
self.ns = infer_function_types(
None, ast.parse(inspect.getsource(test_types)),
dict())
def test_base_types(self):
self.assertIsInstance(self.ns["foo"], values.VBool)
self.assertIsInstance(self.ns["bar"], values.VNone)
self.assertIsInstance(self.ns["d"], values.VInt)
self.assertEqual(self.ns["d"].nbits, 32)
self.assertIsInstance(self.ns["x"], values.VInt)
self.assertEqual(self.ns["x"].nbits, 64)
def test_promotion(self):
for v in "abc":
self.assertIsInstance(self.ns[v], values.VInt)
self.assertEqual(self.ns[v].nbits, 64)
def test_return(self):
self.assertIsInstance(self.ns["return"], values.VInt)
self.assertEqual(self.ns["return"].nbits, 64)
class CompiledFunction:
def __init__(self, function, param_types):
module = lc.Module.new("main")
values.init_module(module)
funcdef = ast.parse(inspect.getsource(function)).body[0]
self.function, self.retval = compile_function(
module, None, funcdef, param_types)
self.argval = [param_types[arg.arg] for arg in funcdef.args.args]
self.executor = le.ExecutionEngine.new(module)
pass_manager = lp.PassManager.new()
add_common_passes(pass_manager)
pass_manager.run(module)
def __call__(self, *args):
args_llvm = [
le.GenericValue.int(av.get_llvm_type(), a)
for av, a in zip(self.argval, args)]
result = self.executor.run_function(self.function, args_llvm)
if isinstance(self.retval, values.VBool):
return bool(result.as_int())
elif isinstance(self.retval, values.VInt):
return result.as_int_signed()
else:
raise NotImplementedError
def is_prime(x):
d = 2
while d*d <= x:
if not x % d:
return False
d += 1
return True
class CodeGenCase(unittest.TestCase):
def test_is_prime(self):
is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)})
for i in range(200):
self.assertEqual(is_prime_c(i), is_prime(i))