forked from M-Labs/artiq
py2llvm: move GCD function into LLVM IR
This commit is contained in:
parent
3c8b541939
commit
15dcf3351b
|
@ -9,7 +9,7 @@ class _RuntimeEnvironment(LinkInterface):
|
|||
self.ref_period = ref_period
|
||||
|
||||
def emit_object(self):
|
||||
return str(self.module)
|
||||
return str(self.llvm_module)
|
||||
|
||||
|
||||
class CoreCom:
|
||||
|
|
|
@ -46,13 +46,13 @@ def _str_to_functype(s):
|
|||
|
||||
class LinkInterface:
|
||||
def init_module(self, module):
|
||||
self.module = module
|
||||
self.llvm_module = module.llvm_module
|
||||
self.var_arg_fixcount = dict()
|
||||
for func_name, func_type_str in _syscalls.items():
|
||||
var_arg_fixcount, func_type = _str_to_functype(func_type_str)
|
||||
if var_arg_fixcount is not None:
|
||||
self.var_arg_fixcount[func_name] = var_arg_fixcount
|
||||
self.module.add_function(func_type, "__syscall_"+func_name)
|
||||
self.llvm_module.add_function(func_type, "__syscall_"+func_name)
|
||||
|
||||
def syscall(self, syscall_name, args, builder):
|
||||
r = _chr_to_value[_syscalls[syscall_name][-1]]()
|
||||
|
@ -63,7 +63,7 @@ class LinkInterface:
|
|||
args = args[:fixcount] \
|
||||
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
|
||||
+ args[fixcount:]
|
||||
llvm_function = self.module.get_function_named(
|
||||
llvm_function = self.llvm_module.get_function_named(
|
||||
"__syscall_" + syscall_name)
|
||||
r.set_ssa_value(builder, builder.call(llvm_function, args))
|
||||
return r
|
||||
|
@ -76,5 +76,5 @@ class Environment(LinkInterface):
|
|||
|
||||
def emit_object(self):
|
||||
tm = lt.TargetMachine.new(triple="or1k", cpu="generic")
|
||||
obj = tm.emit_object(self.module)
|
||||
obj = tm.emit_object(self.llvm_module)
|
||||
return obj
|
||||
|
|
|
@ -1,21 +1,29 @@
|
|||
import inspect
|
||||
import ast
|
||||
|
||||
from llvm import core as lc
|
||||
|
||||
from artiq.py2llvm.values import VGeneric
|
||||
from artiq.py2llvm.base_types import VBool, VInt
|
||||
|
||||
|
||||
def _gcd64(builder, a, b):
|
||||
gcd_f = builder.basic_block.function.module.get_function_named("__gcd64")
|
||||
return builder.call(gcd_f, [a, b])
|
||||
def _gcd(a, b):
|
||||
while a:
|
||||
c = a
|
||||
a = b % a
|
||||
b = c
|
||||
return b
|
||||
|
||||
def init_module(module):
|
||||
func_type = lc.Type.function(
|
||||
lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
|
||||
module.add_function(func_type, "__gcd64")
|
||||
funcdef = ast.parse(inspect.getsource(_gcd)).body[0]
|
||||
module.compile_function(funcdef, {"a": VInt(64), "b": VInt(64)})
|
||||
|
||||
def _call_gcd(builder, a, b):
|
||||
gcd_f = builder.basic_block.function.module.get_function_named("_gcd")
|
||||
return builder.call(gcd_f, [a, b])
|
||||
|
||||
def _frac_normalize(builder, numerator, denominator):
|
||||
gcd = _gcd64(builder, numerator, denominator)
|
||||
gcd = _call_gcd(builder, numerator, denominator)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
return numerator, denominator
|
||||
|
@ -135,12 +143,12 @@ class VFraction(VGeneric):
|
|||
numerator, denominator = self._nd(builder, invert)
|
||||
i = other.get_ssa_value(builder)
|
||||
if div:
|
||||
gcd = _gcd64(i, numerator)
|
||||
gcd = _call_gcd(builder, i, numerator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.mul(denominator, i)
|
||||
else:
|
||||
gcd = _gcd64(i, denominator)
|
||||
gcd = _call_gcd(builder, i, denominator)
|
||||
i = builder.sdiv(i, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
numerator = builder.mul(numerator, i)
|
||||
|
|
|
@ -11,8 +11,8 @@ class Module:
|
|||
self.env = env
|
||||
|
||||
if self.env is not None:
|
||||
self.env.init_module(self.llvm_module)
|
||||
fractions.init_module(self.llvm_module)
|
||||
self.env.init_module(self)
|
||||
fractions.init_module(self)
|
||||
|
||||
def finalize(self):
|
||||
pass_manager = lp.PassManager.new()
|
||||
|
|
|
@ -16,18 +16,6 @@ static const struct symbol syscalls[] = {
|
|||
{NULL, NULL}
|
||||
};
|
||||
|
||||
static long long int gcd64(long long int a, long long int b)
|
||||
{
|
||||
long long int c;
|
||||
|
||||
while(a) {
|
||||
c = a;
|
||||
a = b % a;
|
||||
b = c;
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wimplicit-int"
|
||||
extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
|
||||
|
@ -38,49 +26,51 @@ extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __negsf2, __negdf2,
|
|||
__udivdi3, __umoddi3, __moddi3;
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
static const struct symbol arithmetic[] = {
|
||||
{"__divsi3", &__divsi3},
|
||||
{"__modsi3", &__modsi3},
|
||||
{"__ledf2", &__ledf2},
|
||||
{"__gedf2", &__gedf2},
|
||||
{"__unorddf2", &__gedf2},
|
||||
{"__negsf2", &__negsf2},
|
||||
{"__negdf2", &__negdf2},
|
||||
{"__addsf3", &__addsf3},
|
||||
{"__subsf3", &__subsf3},
|
||||
{"__mulsf3", &__mulsf3},
|
||||
{"__divsf3", &__divsf3},
|
||||
{"__lshrdi3", &__lshrdi3},
|
||||
{"__muldi3", &__muldi3},
|
||||
{"__divdi3", &__divdi3},
|
||||
{"__ashldi3", &__ashldi3},
|
||||
{"__ashrdi3", &__ashrdi3},
|
||||
{"__udivmoddi4", &__udivmoddi4},
|
||||
{"__floatsisf", &__floatsisf},
|
||||
{"__floatunsisf", &__floatunsisf},
|
||||
{"__fixsfsi", &__fixsfsi},
|
||||
{"__fixunssfsi", &__fixunssfsi},
|
||||
{"__adddf3", &__adddf3},
|
||||
{"__subdf3", &__subdf3},
|
||||
{"__muldf3", &__muldf3},
|
||||
{"__divdf3", &__divdf3},
|
||||
{"__floatsidf", &__floatsidf},
|
||||
{"__floatunsidf", &__floatunsidf},
|
||||
{"__floatdidf", &__floatdidf},
|
||||
{"__fixdfsi", &__fixdfsi},
|
||||
{"__fixunsdfsi", &__fixunsdfsi},
|
||||
{"__clzsi2", &__clzsi2},
|
||||
{"__ctzsi2", &__ctzsi2},
|
||||
{"__udivdi3", &__udivdi3},
|
||||
{"__umoddi3", &__umoddi3},
|
||||
{"__moddi3", &__moddi3},
|
||||
{"__gcd64", gcd64},
|
||||
static const struct symbol compiler_rt[] = {
|
||||
{"divsi3", &__divsi3},
|
||||
{"modsi3", &__modsi3},
|
||||
{"ledf2", &__ledf2},
|
||||
{"gedf2", &__gedf2},
|
||||
{"unorddf2", &__gedf2},
|
||||
{"negsf2", &__negsf2},
|
||||
{"negdf2", &__negdf2},
|
||||
{"addsf3", &__addsf3},
|
||||
{"subsf3", &__subsf3},
|
||||
{"mulsf3", &__mulsf3},
|
||||
{"divsf3", &__divsf3},
|
||||
{"lshrdi3", &__lshrdi3},
|
||||
{"muldi3", &__muldi3},
|
||||
{"divdi3", &__divdi3},
|
||||
{"ashldi3", &__ashldi3},
|
||||
{"ashrdi3", &__ashrdi3},
|
||||
{"udivmoddi4", &__udivmoddi4},
|
||||
{"floatsisf", &__floatsisf},
|
||||
{"floatunsisf", &__floatunsisf},
|
||||
{"fixsfsi", &__fixsfsi},
|
||||
{"fixunssfsi", &__fixunssfsi},
|
||||
{"adddf3", &__adddf3},
|
||||
{"subdf3", &__subdf3},
|
||||
{"muldf3", &__muldf3},
|
||||
{"divdf3", &__divdf3},
|
||||
{"floatsidf", &__floatsidf},
|
||||
{"floatunsidf", &__floatunsidf},
|
||||
{"floatdidf", &__floatdidf},
|
||||
{"fixdfsi", &__fixdfsi},
|
||||
{"fixunsdfsi", &__fixunsdfsi},
|
||||
{"clzsi2", &__clzsi2},
|
||||
{"ctzsi2", &__ctzsi2},
|
||||
{"udivdi3", &__udivdi3},
|
||||
{"umoddi3", &__umoddi3},
|
||||
{"moddi3", &__moddi3},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void *resolve_symbol(const char *name)
|
||||
{
|
||||
if(strncmp(name, "__syscall_", 10) == 0)
|
||||
return find_symbol(syscalls, name + 10);
|
||||
return find_symbol(arithmetic, name);
|
||||
if(strncmp(name, "__", 2) != 0)
|
||||
return NULL;
|
||||
name += 2;
|
||||
if(strncmp(name, "syscall_", 8) == 0)
|
||||
return find_symbol(syscalls, name + 8);
|
||||
return find_symbol(compiler_rt, name);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue