forked from M-Labs/artiq
compiler/ir: fraction normalization
This commit is contained in:
parent
a579b105b6
commit
a861226409
|
@ -1,7 +1,7 @@
|
|||
from llvm import core as lc
|
||||
from llvm import passes as lp
|
||||
|
||||
from artiq.compiler import ir_infer_types, ir_ast_body
|
||||
from artiq.compiler import ir_infer_types, ir_ast_body, ir_values
|
||||
|
||||
def compile_function(module, env, funcdef):
|
||||
function_type = lc.Type.function(lc.Type.void(), [])
|
||||
|
@ -18,7 +18,8 @@ def compile_function(module, env, funcdef):
|
|||
|
||||
def get_runtime_binary(env, funcdef):
|
||||
module = lc.Module.new("main")
|
||||
env.set_module(module)
|
||||
env.init_module(module)
|
||||
ir_values.init_module(module)
|
||||
|
||||
compile_function(module, env, funcdef)
|
||||
|
||||
|
|
|
@ -202,17 +202,25 @@ class VFraction(_Value):
|
|||
if not isinstance(other, VFraction):
|
||||
raise TypeError
|
||||
|
||||
def _simplify(self, builder):
|
||||
pass # TODO
|
||||
def _nd(self, builder):
|
||||
ssa_value = self.get_ssa_value(builder)
|
||||
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||
return numerator, denominator
|
||||
|
||||
def set_value_nd(self, builder, numerator, denominator):
|
||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
||||
|
||||
gcd_f = builder.module.get_function_named("__gcd64")
|
||||
gcd = builder.call(gcd_f, [numerator, denominator])
|
||||
numerator = builder.sdiv(numerator, gcd)
|
||||
denominator = builder.sdiv(denominator, gcd)
|
||||
|
||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
||||
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||
self.set_ssa_value(builder, value)
|
||||
self._simplify(builder)
|
||||
|
||||
def set_value(self, builder, n):
|
||||
if not isinstance(n, VFraction):
|
||||
|
@ -232,8 +240,7 @@ class VFraction(_Value):
|
|||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
|
||||
numerator, denominator = self._nd(builder)
|
||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
||||
return r.o_intx(target_bits, builder)
|
||||
|
||||
|
@ -242,8 +249,7 @@ class VFraction(_Value):
|
|||
return VInt(target_bits)
|
||||
else:
|
||||
r = VInt(64)
|
||||
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
||||
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
|
||||
numerator, denominator = self._nd(builder)
|
||||
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||
r_numerator = builder.add(numerator, h_denominator)
|
||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
||||
|
@ -319,3 +325,8 @@ def _make_operators():
|
|||
return SimpleNamespace(**d)
|
||||
|
||||
operators = _make_operators()
|
||||
|
||||
def init_module(module):
|
||||
func_type = lc.Type.function(lc.Type.int(64),
|
||||
[lc.Type.int(64), lc.Type.int(64)])
|
||||
module.add_function(func_type, "__gcd64")
|
||||
|
|
|
@ -40,7 +40,7 @@ def _str_to_functype(s):
|
|||
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None)
|
||||
|
||||
class LinkInterface:
|
||||
def set_module(self, module):
|
||||
def init_module(self, module):
|
||||
self.module = module
|
||||
self.var_arg_fixcount = dict()
|
||||
for func_name, func_type_str in _syscalls.items():
|
||||
|
|
Loading…
Reference in New Issue