py2llvm: move GCD function into LLVM IR

This commit is contained in:
Sebastien Bourdeauducq 2014-09-07 14:46:32 +08:00
parent 3c8b541939
commit 15dcf3351b
5 changed files with 66 additions and 68 deletions

View File

@ -9,7 +9,7 @@ class _RuntimeEnvironment(LinkInterface):
self.ref_period = ref_period
def emit_object(self):
return str(self.module)
return str(self.llvm_module)
class CoreCom:

View File

@ -46,13 +46,13 @@ def _str_to_functype(s):
class LinkInterface:
def init_module(self, module):
self.module = module
self.llvm_module = module.llvm_module
self.var_arg_fixcount = dict()
for func_name, func_type_str in _syscalls.items():
var_arg_fixcount, func_type = _str_to_functype(func_type_str)
if var_arg_fixcount is not None:
self.var_arg_fixcount[func_name] = var_arg_fixcount
self.module.add_function(func_type, "__syscall_"+func_name)
self.llvm_module.add_function(func_type, "__syscall_"+func_name)
def syscall(self, syscall_name, args, builder):
r = _chr_to_value[_syscalls[syscall_name][-1]]()
@ -63,7 +63,7 @@ class LinkInterface:
args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:]
llvm_function = self.module.get_function_named(
llvm_function = self.llvm_module.get_function_named(
"__syscall_" + syscall_name)
r.set_ssa_value(builder, builder.call(llvm_function, args))
return r
@ -76,5 +76,5 @@ class Environment(LinkInterface):
def emit_object(self):
tm = lt.TargetMachine.new(triple="or1k", cpu="generic")
obj = tm.emit_object(self.module)
obj = tm.emit_object(self.llvm_module)
return obj

View File

@ -1,21 +1,29 @@
import inspect
import ast
from llvm import core as lc
from artiq.py2llvm.values import VGeneric
from artiq.py2llvm.base_types import VBool, VInt
def _gcd64(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def _gcd(a, b):
while a:
c = a
a = b % a
b = c
return b
def init_module(module):
func_type = lc.Type.function(
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")
funcdef = ast.parse(inspect.getsource(_gcd)).body[0]
module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)})
def _call_gcd(builder, a, b):
gcd_f = builder.basic_block.function.module.get_function_named("_gcd")
return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(builder, numerator, denominator)
gcd = _call_gcd(builder, numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
@ -135,12 +143,12 @@ class VFraction(VGeneric):
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
gcd = _call_gcd(builder, i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
gcd = _call_gcd(builder, i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)

View File

@ -11,8 +11,8 @@ class Module:
self.env = env
if self.env is not None:
self.env.init_module(self.llvm_module)
fractions.init_module(self.llvm_module)
self.env.init_module(self)
fractions.init_module(self)
def finalize(self):
pass_manager = lp.PassManager.new()

View File

@ -16,18 +16,6 @@ static const struct symbol syscalls[] = {
{NULL, NULL}
};
static long long int gcd64(long long int a, long long int b)
{
long long int c;
while(a) {
c = a;
a = b % a;
b = c;
}
return b;
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wimplicit-int"
extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
@ -38,49 +26,51 @@ extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
__udivdi3, __umoddi3, __moddi3;
#pragma GCC diagnostic pop
static const struct symbol arithmetic[] = {
{"__divsi3", &__divsi3},
{"__modsi3", &__modsi3},
{"__ledf2", &__ledf2},
{"__gedf2", &__gedf2},
{"__unorddf2", &__gedf2},
{"__negsf2", &__negsf2},
{"__negdf2", &__negdf2},
{"__addsf3", &__addsf3},
{"__subsf3", &__subsf3},
{"__mulsf3", &__mulsf3},
{"__divsf3", &__divsf3},
{"__lshrdi3", &__lshrdi3},
{"__muldi3", &__muldi3},
{"__divdi3", &__divdi3},
{"__ashldi3", &__ashldi3},
{"__ashrdi3", &__ashrdi3},
{"__udivmoddi4", &__udivmoddi4},
{"__floatsisf", &__floatsisf},
{"__floatunsisf", &__floatunsisf},
{"__fixsfsi", &__fixsfsi},
{"__fixunssfsi", &__fixunssfsi},
{"__adddf3", &__adddf3},
{"__subdf3", &__subdf3},
{"__muldf3", &__muldf3},
{"__divdf3", &__divdf3},
{"__floatsidf", &__floatsidf},
{"__floatunsidf", &__floatunsidf},
{"__floatdidf", &__floatdidf},
{"__fixdfsi", &__fixdfsi},
{"__fixunsdfsi", &__fixunsdfsi},
{"__clzsi2", &__clzsi2},
{"__ctzsi2", &__ctzsi2},
{"__udivdi3", &__udivdi3},
{"__umoddi3", &__umoddi3},
{"__moddi3", &__moddi3},
{"__gcd64", gcd64},
static const struct symbol compiler_rt[] = {
{"divsi3", &__divsi3},
{"modsi3", &__modsi3},
{"ledf2", &__ledf2},
{"gedf2", &__gedf2},
{"unorddf2", &__gedf2},
{"negsf2", &__negsf2},
{"negdf2", &__negdf2},
{"addsf3", &__addsf3},
{"subsf3", &__subsf3},
{"mulsf3", &__mulsf3},
{"divsf3", &__divsf3},
{"lshrdi3", &__lshrdi3},
{"muldi3", &__muldi3},
{"divdi3", &__divdi3},
{"ashldi3", &__ashldi3},
{"ashrdi3", &__ashrdi3},
{"udivmoddi4", &__udivmoddi4},
{"floatsisf", &__floatsisf},
{"floatunsisf", &__floatunsisf},
{"fixsfsi", &__fixsfsi},
{"fixunssfsi", &__fixunssfsi},
{"adddf3", &__adddf3},
{"subdf3", &__subdf3},
{"muldf3", &__muldf3},
{"divdf3", &__divdf3},
{"floatsidf", &__floatsidf},
{"floatunsidf", &__floatunsidf},
{"floatdidf", &__floatdidf},
{"fixdfsi", &__fixdfsi},
{"fixunsdfsi", &__fixunsdfsi},
{"clzsi2", &__clzsi2},
{"ctzsi2", &__ctzsi2},
{"udivdi3", &__udivdi3},
{"umoddi3", &__umoddi3},
{"moddi3", &__moddi3},
{NULL, NULL}
};
void *resolve_symbol(const char *name)
{
if(strncmp(name, "__syscall_", 10) == 0)
return find_symbol(syscalls, name + 10);
return find_symbol(arithmetic, name);
if(strncmp(name, "__", 2) != 0)
return NULL;
name += 2;
if(strncmp(name, "syscall_", 8) == 0)
return find_symbol(syscalls, name + 8);
return find_symbol(compiler_rt, name);
}