From 6deaf7b81a02039194f602bcaa50d33837da221d Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 6 Sep 2014 19:03:08 +0800 Subject: [PATCH] py2llvm: add support for function parameters and return values, add unit test --- artiq/py2llvm/__init__.py | 26 ++-------- artiq/py2llvm/ast_body.py | 38 +++++++++++---- artiq/py2llvm/functions.py | 31 ++++++++++++ artiq/py2llvm/infer_types.py | 66 ++++++++++++------------- artiq/py2llvm/tools.py | 11 +++++ artiq/py2llvm/values.py | 2 +- test/py2llvm.py | 94 ++++++++++++++++++++++++++++++++++++ 7 files changed, 200 insertions(+), 68 deletions(-) create mode 100644 artiq/py2llvm/functions.py create mode 100644 artiq/py2llvm/tools.py create mode 100644 test/py2llvm.py diff --git a/artiq/py2llvm/__init__.py b/artiq/py2llvm/__init__.py index ba273cb8c..6f23f0fc0 100644 --- a/artiq/py2llvm/__init__.py +++ b/artiq/py2llvm/__init__.py @@ -1,21 +1,9 @@ from llvm import core as lc from llvm import passes as lp -from artiq.py2llvm import infer_types, ast_body, values - - -def _compile_function(module, env, funcdef): - function_type = lc.Type.function(lc.Type.void(), []) - function = module.add_function(function_type, funcdef.name) - bb = function.append_basic_block("entry") - builder = lc.Builder.new(bb) - - ns = infer_types.infer_types(env, funcdef) - for k, v in ns.items(): - v.alloca(builder, k) - visitor = ast_body.Visitor(env, ns, builder) - visitor.visit_statements(funcdef.body) - builder.ret_void() +from artiq.py2llvm import values +from artiq.py2llvm.functions import compile_function +from artiq.py2llvm.tools import add_common_passes def get_runtime_binary(env, funcdef): @@ -23,14 +11,10 @@ def get_runtime_binary(env, funcdef): env.init_module(module) values.init_module(module) - _compile_function(module, env, funcdef) + compile_function(module, env, funcdef, dict()) pass_manager = lp.PassManager.new() - pass_manager.add(lp.PASS_MEM2REG) - pass_manager.add(lp.PASS_INSTCOMBINE) - pass_manager.add(lp.PASS_REASSOCIATE) - pass_manager.add(lp.PASS_GVN) - pass_manager.add(lp.PASS_SIMPLIFYCFG) + add_common_passes(pass_manager) pass_manager.run(module) return env.emit_object() diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py index 2e3833c77..86d731251 100644 --- a/artiq/py2llvm/ast_body.py +++ b/artiq/py2llvm/ast_body.py @@ -1,6 +1,7 @@ import ast from artiq.py2llvm import values +from artiq.py2llvm.tools import is_terminated class Visitor: @@ -131,13 +132,16 @@ class Visitor: def visit_statements(self, stmts): for node in stmts: - method = "_visit_stmt_" + node.__class__.__name__ + node_type = node.__class__.__name__ + method = "_visit_stmt_" + node_type try: visitor = getattr(self, method) except AttributeError: raise NotImplementedError("Unsupported node '{}' in statement" - .format(node.__class__.__name__)) + .format(node_type)) visitor(node) + if node_type == "Return": + break def _visit_stmt_Assign(self, node): val = self.visit_expression(node.value) @@ -165,17 +169,19 @@ class Visitor: merge_block = function.append_basic_block("i_merge") condition = values.operators.bool(self.visit_expression(node.test), - self.builder) + self.builder) self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block) self.builder.position_at_end(then_block) self.visit_statements(node.body) - self.builder.branch(merge_block) + if not is_terminated(self.builder.basic_block): + self.builder.branch(merge_block) self.builder.position_at_end(else_block) self.visit_statements(node.orelse) - self.builder.branch(merge_block) + if not is_terminated(self.builder.basic_block): + self.builder.branch(merge_block) self.builder.position_at_end(merge_block) @@ -192,13 +198,25 @@ class Visitor: self.builder.position_at_end(body_block) self.visit_statements(node.body) - condition = values.operators.bool( - self.visit_expression(node.test), self.builder) - self.builder.cbranch( - condition.get_ssa_value(self.builder), body_block, merge_block) + if not is_terminated(self.builder.basic_block): + condition = values.operators.bool( + self.visit_expression(node.test), self.builder) + self.builder.cbranch( + condition.get_ssa_value(self.builder), body_block, merge_block) self.builder.position_at_end(else_block) self.visit_statements(node.orelse) - self.builder.branch(merge_block) + if not is_terminated(self.builder.basic_block): + self.builder.branch(merge_block) self.builder.position_at_end(merge_block) + + def _visit_stmt_Return(self, node): + if node.value is None: + val = values.VNone() + else: + val = self.visit_expression(node.value) + if isinstance(val, values.VNone): + self.builder.ret_void() + else: + self.builder.ret(val.get_ssa_value(self.builder)) diff --git a/artiq/py2llvm/functions.py b/artiq/py2llvm/functions.py new file mode 100644 index 000000000..cab173105 --- /dev/null +++ b/artiq/py2llvm/functions.py @@ -0,0 +1,31 @@ +from llvm import core as lc + +from artiq.py2llvm import infer_types, ast_body, values, tools + +def compile_function(module, env, funcdef, param_types): + ns = infer_types.infer_function_types(env, funcdef, param_types) + retval = ns["return"] + + function_type = lc.Type.function(retval.get_llvm_type(), + [ns[arg.arg].get_llvm_type() for arg in funcdef.args.args]) + function = module.add_function(function_type, funcdef.name) + bb = function.append_basic_block("entry") + builder = lc.Builder.new(bb) + + for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): + arg_llvm.name = arg_ast.arg + for k, v in ns.items(): + v.alloca(builder, k) + for arg_ast, arg_llvm in zip(funcdef.args.args, function.args): + ns[arg_ast.arg].set_ssa_value(builder, arg_llvm) + + visitor = ast_body.Visitor(env, ns, builder) + visitor.visit_statements(funcdef.body) + + if not tools.is_terminated(builder.basic_block): + if isinstance(retval, values.VNone): + builder.ret_void() + else: + builder.ret(retval.get_ssa_value(builder)) + + return function, retval diff --git a/artiq/py2llvm/infer_types.py b/artiq/py2llvm/infer_types.py index 2057d4eaa..5e542db34 100644 --- a/artiq/py2llvm/infer_types.py +++ b/artiq/py2llvm/infer_types.py @@ -1,61 +1,55 @@ import ast -from operator import itemgetter from copy import deepcopy from artiq.py2llvm.ast_body import Visitor +from artiq.py2llvm import values class _TypeScanner(ast.NodeVisitor): def __init__(self, env, ns): self.exprv = Visitor(env, ns) - def visit_Assign(self, node): - val = self.exprv.visit_expression(node.value) + def _update_target(self, target, val): ns = self.exprv.ns - for target in node.targets: - if isinstance(target, ast.Name): - if target.id in ns: - ns[target.id].merge(val) - else: - ns[target.id] = val - else: - raise NotImplementedError - - def visit_AugAssign(self, node): - val = self.exprv.visit_expression(ast.BinOp( - op=node.op, left=node.target, right=node.value)) - ns = self.exprv.ns - target = node.target if isinstance(target, ast.Name): if target.id in ns: ns[target.id].merge(val) else: - ns[target.id] = val + ns[target.id] = deepcopy(val) else: raise NotImplementedError + def visit_Assign(self, node): + val = self.exprv.visit_expression(node.value) + for target in node.targets: + self._update_target(target, val) -def infer_types(env, node): - ns = dict() + def visit_AugAssign(self, node): + val = self.exprv.visit_expression(ast.BinOp( + op=node.op, left=node.target, right=node.value)) + self._update_target(node.target, val) + + def visit_Return(self, node): + if node.value is None: + val = values.VNone() + else: + val = self.exprv.visit_expression(node.value) + ns = self.exprv.ns + if "return" in ns: + ns["return"].merge(val) + else: + ns["return"] = deepcopy(val) + +def infer_function_types(env, node, param_types): + ns = deepcopy(param_types) + ts = _TypeScanner(env, ns) + ts.visit(node) while True: prev_ns = deepcopy(ns) ts = _TypeScanner(env, ns) ts.visit(node) - if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): + if all(v.same_type(prev_ns[k]) for k, v in ns.items()): # no more promotions - completed + if "return" not in ns: + ns["return"] = values.VNone() return ns - -if __name__ == "__main__": - testcode = """ -a = 2 # promoted later to int64 -b = a + 1 # initially int32, becomes int64 after a is promoted -c = b//2 # initially int32, becomes int64 after b is promoted -d = 4 # stays int32 -x = int64(7) -a += x # promotes a to int64 -foo = True -bar = None -""" - ns = infer_types(None, ast.parse(testcode)) - for k, v in sorted(ns.items(), key=itemgetter(0)): - print("{:10}--> {}".format(k, str(v))) diff --git a/artiq/py2llvm/tools.py b/artiq/py2llvm/tools.py new file mode 100644 index 000000000..bdc3a0791 --- /dev/null +++ b/artiq/py2llvm/tools.py @@ -0,0 +1,11 @@ +from llvm import passes as lp + +def is_terminated(basic_block): + return basic_block.instructions and basic_block.instructions[-1].is_terminator + +def add_common_passes(pass_manager): + pass_manager.add(lp.PASS_MEM2REG) + pass_manager.add(lp.PASS_INSTCOMBINE) + pass_manager.add(lp.PASS_REASSOCIATE) + pass_manager.add(lp.PASS_GVN) + pass_manager.add(lp.PASS_SIMPLIFYCFG) diff --git a/artiq/py2llvm/values.py b/artiq/py2llvm/values.py index 3ecdf94ec..1760cb8b5 100644 --- a/artiq/py2llvm/values.py +++ b/artiq/py2llvm/values.py @@ -24,7 +24,7 @@ class _Value: def alloca(self, builder, name): if self._llvm_value is not None: - raise RuntimeError("Attempted to alloca existing LLVM value") + raise RuntimeError("Attempted to alloca existing LLVM value "+name) self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) def o_int(self, builder): diff --git a/test/py2llvm.py b/test/py2llvm.py new file mode 100644 index 000000000..7b0074421 --- /dev/null +++ b/test/py2llvm.py @@ -0,0 +1,94 @@ +import unittest +import ast +import inspect + +from llvm import core as lc +from llvm import passes as lp +from llvm import ee as le + +from artiq.py2llvm.infer_types import infer_function_types +from artiq.py2llvm import values +from artiq.py2llvm import compile_function +from artiq.py2llvm.tools import add_common_passes + + +def test_types(choice): + a = 2 # promoted later to int64 + b = a + 1 # initially int32, becomes int64 after a is promoted + c = b//2 # initially int32, becomes int64 after b is promoted + d = 4 # stays int32 + x = int64(7) + a += x # promotes a to int64 + foo = True + bar = None + + if choice: + return 3 + else: + return x + +class FunctionTypesCase(unittest.TestCase): + def setUp(self): + self.ns = infer_function_types( + None, ast.parse(inspect.getsource(test_types)), + dict()) + + def test_base_types(self): + self.assertIsInstance(self.ns["foo"], values.VBool) + self.assertIsInstance(self.ns["bar"], values.VNone) + self.assertIsInstance(self.ns["d"], values.VInt) + self.assertEqual(self.ns["d"].nbits, 32) + self.assertIsInstance(self.ns["x"], values.VInt) + self.assertEqual(self.ns["x"].nbits, 64) + + def test_promotion(self): + for v in "abc": + self.assertIsInstance(self.ns[v], values.VInt) + self.assertEqual(self.ns[v].nbits, 64) + + def test_return(self): + self.assertIsInstance(self.ns["return"], values.VInt) + self.assertEqual(self.ns["return"].nbits, 64) + + +class CompiledFunction: + def __init__(self, function, param_types): + module = lc.Module.new("main") + values.init_module(module) + + funcdef = ast.parse(inspect.getsource(function)).body[0] + self.function, self.retval = compile_function( + module, None, funcdef, param_types) + self.argval = [param_types[arg.arg] for arg in funcdef.args.args] + + self.executor = le.ExecutionEngine.new(module) + pass_manager = lp.PassManager.new() + add_common_passes(pass_manager) + pass_manager.run(module) + + def __call__(self, *args): + args_llvm = [ + le.GenericValue.int(av.get_llvm_type(), a) + for av, a in zip(self.argval, args)] + result = self.executor.run_function(self.function, args_llvm) + if isinstance(self.retval, values.VBool): + return bool(result.as_int()) + elif isinstance(self.retval, values.VInt): + return result.as_int_signed() + else: + raise NotImplementedError + + +def is_prime(x): + d = 2 + while d*d <= x: + if not x % d: + return False + d += 1 + return True + +class CodeGenCase(unittest.TestCase): + def test_is_prime(self): + is_prime_c = CompiledFunction(is_prime, {"x": values.VInt(32)}) + for i in range(200): + self.assertEqual(is_prime_c(i), is_prime(i))