py2llvm: move GCD function into LLVM IR

This commit is contained in:
Sebastien Bourdeauducq 2014-09-07 14:46:32 +08:00
parent 3c8b541939
commit 15dcf3351b
5 changed files with 66 additions and 68 deletions

View File

@ -9,7 +9,7 @@ class _RuntimeEnvironment(LinkInterface):
self.ref_period = ref_period self.ref_period = ref_period
def emit_object(self): def emit_object(self):
return str(self.module) return str(self.llvm_module)
class CoreCom: class CoreCom:

View File

@ -46,13 +46,13 @@ def _str_to_functype(s):
class LinkInterface: class LinkInterface:
def init_module(self, module): def init_module(self, module):
self.module = module self.llvm_module = module.llvm_module
self.var_arg_fixcount = dict() self.var_arg_fixcount = dict()
for func_name, func_type_str in _syscalls.items(): for func_name, func_type_str in _syscalls.items():
var_arg_fixcount, func_type = _str_to_functype(func_type_str) var_arg_fixcount, func_type = _str_to_functype(func_type_str)
if var_arg_fixcount is not None: if var_arg_fixcount is not None:
self.var_arg_fixcount[func_name] = var_arg_fixcount self.var_arg_fixcount[func_name] = var_arg_fixcount
self.module.add_function(func_type, "__syscall_"+func_name) self.llvm_module.add_function(func_type, "__syscall_"+func_name)
def syscall(self, syscall_name, args, builder): def syscall(self, syscall_name, args, builder):
r = _chr_to_value[_syscalls[syscall_name][-1]]() r = _chr_to_value[_syscalls[syscall_name][-1]]()
@ -63,7 +63,7 @@ class LinkInterface:
args = args[:fixcount] \ args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:] + args[fixcount:]
llvm_function = self.module.get_function_named( llvm_function = self.llvm_module.get_function_named(
"__syscall_" + syscall_name) "__syscall_" + syscall_name)
r.set_ssa_value(builder, builder.call(llvm_function, args)) r.set_ssa_value(builder, builder.call(llvm_function, args))
return r return r
@ -76,5 +76,5 @@ class Environment(LinkInterface):
def emit_object(self): def emit_object(self):
tm = lt.TargetMachine.new(triple="or1k", cpu="generic") tm = lt.TargetMachine.new(triple="or1k", cpu="generic")
obj = tm.emit_object(self.module) obj = tm.emit_object(self.llvm_module)
return obj return obj

View File

@ -1,21 +1,29 @@
import inspect
import ast
from llvm import core as lc from llvm import core as lc
from artiq.py2llvm.values import VGeneric from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VBool, VInt from artiq.py2llvm.base_types import VBool, VInt
def _gcd64(builder, a, b): def _gcd(a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64") while a:
return builder.call(gcd_f, [a, b]) c = a
a = b % a
b = c
return b
def init_module(module): def init_module(module):
func_type = lc.Type.function( funcdef = ast.parse(inspect.getsource(_gcd)).body[0]
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)]) module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)})
module.add_function(func_type, "__gcd64")
def _call_gcd(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("_gcd")
return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator): def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator) gcd = _call_gcd(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd) numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
return numerator, denominator return numerator, denominator
@ -135,12 +143,12 @@ class VFraction(VGeneric):
numerator, denominator = self._nd(builder, invert) numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder) i = other.get_ssa_value(builder)
if div: if div:
gcd = _gcd64(i, numerator) gcd = _call_gcd(builder, i, numerator)
i = builder.sdiv(i, gcd) i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd) numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i) denominator = builder.mul(denominator, i)
else: else:
gcd = _gcd64(i, denominator) gcd = _call_gcd(builder, i, denominator)
i = builder.sdiv(i, gcd) i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i) numerator = builder.mul(numerator, i)

View File

@ -11,8 +11,8 @@ class Module:
self.env = env self.env = env
if self.env is not None: if self.env is not None:
self.env.init_module(self.llvm_module) self.env.init_module(self)
fractions.init_module(self.llvm_module) fractions.init_module(self)
def finalize(self): def finalize(self):
pass_manager = lp.PassManager.new() pass_manager = lp.PassManager.new()

View File

@ -16,18 +16,6 @@ static const struct symbol syscalls[] = {
{NULL, NULL} {NULL, NULL}
}; };
static long long int gcd64(long long int a, long long int b)
{
long long int c;
while(a) {
c = a;
a = b % a;
b = c;
}
return b;
}
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wimplicit-int" #pragma GCC diagnostic ignored "-Wimplicit-int"
extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2, extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
@ -38,49 +26,51 @@ extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
__udivdi3, __umoddi3, __moddi3; __udivdi3, __umoddi3, __moddi3;
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
static const struct symbol arithmetic[] = { static const struct symbol compiler_rt[] = {
{"__divsi3", &__divsi3}, {"divsi3", &__divsi3},
{"__modsi3", &__modsi3}, {"modsi3", &__modsi3},
{"__ledf2", &__ledf2}, {"ledf2", &__ledf2},
{"__gedf2", &__gedf2}, {"gedf2", &__gedf2},
{"__unorddf2", &__gedf2}, {"unorddf2", &__gedf2},
{"__negsf2", &__negsf2}, {"negsf2", &__negsf2},
{"__negdf2", &__negdf2}, {"negdf2", &__negdf2},
{"__addsf3", &__addsf3}, {"addsf3", &__addsf3},
{"__subsf3", &__subsf3}, {"subsf3", &__subsf3},
{"__mulsf3", &__mulsf3}, {"mulsf3", &__mulsf3},
{"__divsf3", &__divsf3}, {"divsf3", &__divsf3},
{"__lshrdi3", &__lshrdi3}, {"lshrdi3", &__lshrdi3},
{"__muldi3", &__muldi3}, {"muldi3", &__muldi3},
{"__divdi3", &__divdi3}, {"divdi3", &__divdi3},
{"__ashldi3", &__ashldi3}, {"ashldi3", &__ashldi3},
{"__ashrdi3", &__ashrdi3}, {"ashrdi3", &__ashrdi3},
{"__udivmoddi4", &__udivmoddi4}, {"udivmoddi4", &__udivmoddi4},
{"__floatsisf", &__floatsisf}, {"floatsisf", &__floatsisf},
{"__floatunsisf", &__floatunsisf}, {"floatunsisf", &__floatunsisf},
{"__fixsfsi", &__fixsfsi}, {"fixsfsi", &__fixsfsi},
{"__fixunssfsi", &__fixunssfsi}, {"fixunssfsi", &__fixunssfsi},
{"__adddf3", &__adddf3}, {"adddf3", &__adddf3},
{"__subdf3", &__subdf3}, {"subdf3", &__subdf3},
{"__muldf3", &__muldf3}, {"muldf3", &__muldf3},
{"__divdf3", &__divdf3}, {"divdf3", &__divdf3},
{"__floatsidf", &__floatsidf}, {"floatsidf", &__floatsidf},
{"__floatunsidf", &__floatunsidf}, {"floatunsidf", &__floatunsidf},
{"__floatdidf", &__floatdidf}, {"floatdidf", &__floatdidf},
{"__fixdfsi", &__fixdfsi}, {"fixdfsi", &__fixdfsi},
{"__fixunsdfsi", &__fixunsdfsi}, {"fixunsdfsi", &__fixunsdfsi},
{"__clzsi2", &__clzsi2}, {"clzsi2", &__clzsi2},
{"__ctzsi2", &__ctzsi2}, {"ctzsi2", &__ctzsi2},
{"__udivdi3", &__udivdi3}, {"udivdi3", &__udivdi3},
{"__umoddi3", &__umoddi3}, {"umoddi3", &__umoddi3},
{"__moddi3", &__moddi3}, {"moddi3", &__moddi3},
{"__gcd64", gcd64},
{NULL, NULL} {NULL, NULL}
}; };
void *resolve_symbol(const char *name) void *resolve_symbol(const char *name)
{ {
if(strncmp(name, "__syscall_", 10) == 0) if(strncmp(name, "__", 2) != 0)
return find_symbol(syscalls, name + 10); return NULL;
return find_symbol(arithmetic, name); name += 2;
if(strncmp(name, "syscall_", 8) == 0)
return find_symbol(syscalls, name + 8);
return find_symbol(compiler_rt, name);
} }