From 15dcf3351bd700af7ea2e932fe68e22cac9766b7 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 7 Sep 2014 14:46:32 +0800 Subject: [PATCH] py2llvm: move GCD function into LLVM IR --- artiq/devices/corecom_dummy.py | 2 +- artiq/devices/runtime.py | 8 +-- artiq/py2llvm/fractions.py | 26 ++++++---- artiq/py2llvm/module.py | 4 +- soc/runtime/symbols.c | 94 +++++++++++++++------------------- 5 files changed, 66 insertions(+), 68 deletions(-) diff --git a/artiq/devices/corecom_dummy.py b/artiq/devices/corecom_dummy.py index 4d3cb5aef..61a40a929 100644 --- a/artiq/devices/corecom_dummy.py +++ b/artiq/devices/corecom_dummy.py @@ -9,7 +9,7 @@ class _RuntimeEnvironment(LinkInterface): self.ref_period = ref_period def emit_object(self): - return str(self.module) + return str(self.llvm_module) class CoreCom: diff --git a/artiq/devices/runtime.py b/artiq/devices/runtime.py index 2554fced9..1c8d88275 100644 --- a/artiq/devices/runtime.py +++ b/artiq/devices/runtime.py @@ -46,13 +46,13 @@ def _str_to_functype(s): class LinkInterface: def init_module(self, module): - self.module = module + self.llvm_module = module.llvm_module self.var_arg_fixcount = dict() for func_name, func_type_str in _syscalls.items(): var_arg_fixcount, func_type = _str_to_functype(func_type_str) if var_arg_fixcount is not None: self.var_arg_fixcount[func_name] = var_arg_fixcount - self.module.add_function(func_type, "__syscall_"+func_name) + self.llvm_module.add_function(func_type, "__syscall_"+func_name) def syscall(self, syscall_name, args, builder): r = _chr_to_value[_syscalls[syscall_name][-1]]() @@ -63,7 +63,7 @@ class LinkInterface: args = args[:fixcount] \ + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + args[fixcount:] - llvm_function = self.module.get_function_named( + llvm_function = self.llvm_module.get_function_named( "__syscall_" + syscall_name) r.set_ssa_value(builder, builder.call(llvm_function, args)) return r @@ -76,5 +76,5 @@ class Environment(LinkInterface): def emit_object(self): tm = lt.TargetMachine.new(triple="or1k", cpu="generic") - obj = tm.emit_object(self.module) + obj = tm.emit_object(self.llvm_module) return obj diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm/fractions.py index ff46e90ae..3544d7275 100644 --- a/artiq/py2llvm/fractions.py +++ b/artiq/py2llvm/fractions.py @@ -1,21 +1,29 @@ +import inspect +import ast + from llvm import core as lc from artiq.py2llvm.values import VGeneric from artiq.py2llvm.base_types import VBool, VInt -def _gcd64(builder, a, b): - gcd_f = builder.basic_block.function.module.get_function_named("__gcd64") - return builder.call(gcd_f, [a, b]) +def _gcd(a, b): + while a: + c = a + a = b % a + b = c + return b def init_module(module): - func_type = lc.Type.function( - lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)]) - module.add_function(func_type, "__gcd64") + funcdef = ast.parse(inspect.getsource(_gcd)).body[0] + module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)}) +def _call_gcd(builder, a, b): + gcd_f = builder.basic_block.function.module.get_function_named("_gcd") + return builder.call(gcd_f, [a, b]) def _frac_normalize(builder, numerator, denominator): - gcd = _gcd64(builder, numerator, denominator) + gcd = _call_gcd(builder, numerator, denominator) numerator = builder.sdiv(numerator, gcd) denominator = builder.sdiv(denominator, gcd) return numerator, denominator @@ -135,12 +143,12 @@ class VFraction(VGeneric): numerator, denominator = self._nd(builder, invert) i = other.get_ssa_value(builder) if div: - gcd = _gcd64(i, numerator) + gcd = _call_gcd(builder, i, numerator) i = builder.sdiv(i, gcd) numerator = builder.sdiv(numerator, gcd) denominator = builder.mul(denominator, i) else: - gcd = _gcd64(i, denominator) + gcd = _call_gcd(builder, i, denominator) i = builder.sdiv(i, gcd) denominator = builder.sdiv(denominator, gcd) numerator = builder.mul(numerator, i) diff --git a/artiq/py2llvm/module.py b/artiq/py2llvm/module.py index 4dfd646ba..43c9bd418 100644 --- a/artiq/py2llvm/module.py +++ b/artiq/py2llvm/module.py @@ -11,8 +11,8 @@ class Module: self.env = env if self.env is not None: - self.env.init_module(self.llvm_module) - fractions.init_module(self.llvm_module) + self.env.init_module(self) + fractions.init_module(self) def finalize(self): pass_manager = lp.PassManager.new() diff --git a/soc/runtime/symbols.c b/soc/runtime/symbols.c index 224c98941..5f04ce347 100644 --- a/soc/runtime/symbols.c +++ b/soc/runtime/symbols.c @@ -16,18 +16,6 @@ static const struct symbol syscalls[] = { {NULL, NULL} }; -static long long int gcd64(long long int a, long long int b) -{ - long long int c; - - while(a) { - c = a; - a = b % a; - b = c; - } - return b; -} - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wimplicit-int" extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2, @@ -38,49 +26,51 @@ extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2, __udivdi3, __umoddi3, __moddi3; #pragma GCC diagnostic pop -static const struct symbol arithmetic[] = { - {"__divsi3", &__divsi3}, - {"__modsi3", &__modsi3}, - {"__ledf2", &__ledf2}, - {"__gedf2", &__gedf2}, - {"__unorddf2", &__gedf2}, - {"__negsf2", &__negsf2}, - {"__negdf2", &__negdf2}, - {"__addsf3", &__addsf3}, - {"__subsf3", &__subsf3}, - {"__mulsf3", &__mulsf3}, - {"__divsf3", &__divsf3}, - {"__lshrdi3", &__lshrdi3}, - {"__muldi3", &__muldi3}, - {"__divdi3", &__divdi3}, - {"__ashldi3", &__ashldi3}, - {"__ashrdi3", &__ashrdi3}, - {"__udivmoddi4", &__udivmoddi4}, - {"__floatsisf", &__floatsisf}, - {"__floatunsisf", &__floatunsisf}, - {"__fixsfsi", &__fixsfsi}, - {"__fixunssfsi", &__fixunssfsi}, - {"__adddf3", &__adddf3}, - {"__subdf3", &__subdf3}, - {"__muldf3", &__muldf3}, - {"__divdf3", &__divdf3}, - {"__floatsidf", &__floatsidf}, - {"__floatunsidf", &__floatunsidf}, - {"__floatdidf", &__floatdidf}, - {"__fixdfsi", &__fixdfsi}, - {"__fixunsdfsi", &__fixunsdfsi}, - {"__clzsi2", &__clzsi2}, - {"__ctzsi2", &__ctzsi2}, - {"__udivdi3", &__udivdi3}, - {"__umoddi3", &__umoddi3}, - {"__moddi3", &__moddi3}, - {"__gcd64", gcd64}, +static const struct symbol compiler_rt[] = { + {"divsi3", &__divsi3}, + {"modsi3", &__modsi3}, + {"ledf2", &__ledf2}, + {"gedf2", &__gedf2}, + {"unorddf2", &__gedf2}, + {"negsf2", &__negsf2}, + {"negdf2", &__negdf2}, + {"addsf3", &__addsf3}, + {"subsf3", &__subsf3}, + {"mulsf3", &__mulsf3}, + {"divsf3", &__divsf3}, + {"lshrdi3", &__lshrdi3}, + {"muldi3", &__muldi3}, + {"divdi3", &__divdi3}, + {"ashldi3", &__ashldi3}, + {"ashrdi3", &__ashrdi3}, + {"udivmoddi4", &__udivmoddi4}, + {"floatsisf", &__floatsisf}, + {"floatunsisf", &__floatunsisf}, + {"fixsfsi", &__fixsfsi}, + {"fixunssfsi", &__fixunssfsi}, + {"adddf3", &__adddf3}, + {"subdf3", &__subdf3}, + {"muldf3", &__muldf3}, + {"divdf3", &__divdf3}, + {"floatsidf", &__floatsidf}, + {"floatunsidf", &__floatunsidf}, + {"floatdidf", &__floatdidf}, + {"fixdfsi", &__fixdfsi}, + {"fixunsdfsi", &__fixunsdfsi}, + {"clzsi2", &__clzsi2}, + {"ctzsi2", &__ctzsi2}, + {"udivdi3", &__udivdi3}, + {"umoddi3", &__umoddi3}, + {"moddi3", &__moddi3}, {NULL, NULL} }; void *resolve_symbol(const char *name) { - if(strncmp(name, "__syscall_", 10) == 0) - return find_symbol(syscalls, name + 10); - return find_symbol(arithmetic, name); + if(strncmp(name, "__", 2) != 0) + return NULL; + name += 2; + if(strncmp(name, "syscall_", 8) == 0) + return find_symbol(syscalls, name + 8); + return find_symbol(compiler_rt, name); }