compiler/ir: fraction normalization

This commit is contained in:
Sebastien Bourdeauducq 2014-08-28 17:24:33 +08:00
parent a579b105b6
commit a861226409
3 changed files with 22 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from llvm import core as lc from llvm import core as lc
from llvm import passes as lp from llvm import passes as lp
from artiq.compiler import ir_infer_types, ir_ast_body from artiq.compiler import ir_infer_types, ir_ast_body, ir_values
def compile_function(module, env, funcdef): def compile_function(module, env, funcdef):
function_type = lc.Type.function(lc.Type.void(), []) function_type = lc.Type.function(lc.Type.void(), [])
@ -18,7 +18,8 @@ def compile_function(module, env, funcdef):
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
module = lc.Module.new("main") module = lc.Module.new("main")
env.set_module(module) env.init_module(module)
ir_values.init_module(module)
compile_function(module, env, funcdef) compile_function(module, env, funcdef)

View File

@ -202,17 +202,25 @@ class VFraction(_Value):
if not isinstance(other, VFraction): if not isinstance(other, VFraction):
raise TypeError raise TypeError
def _simplify(self, builder): def _nd(self, builder):
pass # TODO ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1))
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator): def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder) numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder)
gcd_f = builder.module.get_function_named("__gcd64")
gcd = builder.call(gcd_f, [numerator, denominator])
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
self.set_ssa_value(builder, value) self.set_ssa_value(builder, value)
self._simplify(builder)
def set_value(self, builder, n): def set_value(self, builder, n):
if not isinstance(n, VFraction): if not isinstance(n, VFraction):
@ -232,8 +240,7 @@ class VFraction(_Value):
return VInt(target_bits) return VInt(target_bits)
else: else:
r = VInt(64) r = VInt(64)
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) numerator, denominator = self._nd(builder)
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder) return r.o_intx(target_bits, builder)
@ -242,8 +249,7 @@ class VFraction(_Value):
return VInt(target_bits) return VInt(target_bits)
else: else:
r = VInt(64) r = VInt(64)
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) numerator, denominator = self._nd(builder)
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1)) h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator) r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
@ -319,3 +325,8 @@ def _make_operators():
return SimpleNamespace(**d) return SimpleNamespace(**d)
operators = _make_operators() operators = _make_operators()
def init_module(module):
func_type = lc.Type.function(lc.Type.int(64),
[lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64")

View File

@ -40,7 +40,7 @@ def _str_to_functype(s):
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None) return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None)
class LinkInterface: class LinkInterface:
def set_module(self, module): def init_module(self, module):
self.module = module self.module = module
self.var_arg_fixcount = dict() self.var_arg_fixcount = dict()
for func_name, func_type_str in _syscalls.items(): for func_name, func_type_str in _syscalls.items():