mirror of https://github.com/m-labs/artiq.git
compiler/ir: fraction normalization
This commit is contained in:
parent
a579b105b6
commit
a861226409
|
@ -1,7 +1,7 @@
|
||||||
from llvm import core as lc
|
from llvm import core as lc
|
||||||
from llvm import passes as lp
|
from llvm import passes as lp
|
||||||
|
|
||||||
from artiq.compiler import ir_infer_types, ir_ast_body
|
from artiq.compiler import ir_infer_types, ir_ast_body, ir_values
|
||||||
|
|
||||||
def compile_function(module, env, funcdef):
|
def compile_function(module, env, funcdef):
|
||||||
function_type = lc.Type.function(lc.Type.void(), [])
|
function_type = lc.Type.function(lc.Type.void(), [])
|
||||||
|
@ -18,7 +18,8 @@ def compile_function(module, env, funcdef):
|
||||||
|
|
||||||
def get_runtime_binary(env, funcdef):
|
def get_runtime_binary(env, funcdef):
|
||||||
module = lc.Module.new("main")
|
module = lc.Module.new("main")
|
||||||
env.set_module(module)
|
env.init_module(module)
|
||||||
|
ir_values.init_module(module)
|
||||||
|
|
||||||
compile_function(module, env, funcdef)
|
compile_function(module, env, funcdef)
|
||||||
|
|
||||||
|
|
|
@ -202,17 +202,25 @@ class VFraction(_Value):
|
||||||
if not isinstance(other, VFraction):
|
if not isinstance(other, VFraction):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def _simplify(self, builder):
|
def _nd(self, builder):
|
||||||
pass # TODO
|
ssa_value = self.get_ssa_value(builder)
|
||||||
|
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0))
|
||||||
|
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1))
|
||||||
|
return numerator, denominator
|
||||||
|
|
||||||
def set_value_nd(self, builder, numerator, denominator):
|
def set_value_nd(self, builder, numerator, denominator):
|
||||||
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
numerator = numerator.o_int64(builder).get_ssa_value(builder)
|
||||||
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
denominator = denominator.o_int64(builder).get_ssa_value(builder)
|
||||||
|
|
||||||
|
gcd_f = builder.module.get_function_named("__gcd64")
|
||||||
|
gcd = builder.call(gcd_f, [numerator, denominator])
|
||||||
|
numerator = builder.sdiv(numerator, gcd)
|
||||||
|
denominator = builder.sdiv(denominator, gcd)
|
||||||
|
|
||||||
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
|
||||||
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
|
||||||
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||||
self.set_ssa_value(builder, value)
|
self.set_ssa_value(builder, value)
|
||||||
self._simplify(builder)
|
|
||||||
|
|
||||||
def set_value(self, builder, n):
|
def set_value(self, builder, n):
|
||||||
if not isinstance(n, VFraction):
|
if not isinstance(n, VFraction):
|
||||||
|
@ -232,8 +240,7 @@ class VFraction(_Value):
|
||||||
return VInt(target_bits)
|
return VInt(target_bits)
|
||||||
else:
|
else:
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
numerator, denominator = self._nd(builder)
|
||||||
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
|
|
||||||
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
|
||||||
return r.o_intx(target_bits, builder)
|
return r.o_intx(target_bits, builder)
|
||||||
|
|
||||||
|
@ -242,8 +249,7 @@ class VFraction(_Value):
|
||||||
return VInt(target_bits)
|
return VInt(target_bits)
|
||||||
else:
|
else:
|
||||||
r = VInt(64)
|
r = VInt(64)
|
||||||
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
|
numerator, denominator = self._nd(builder)
|
||||||
denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1))
|
|
||||||
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1))
|
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1))
|
||||||
r_numerator = builder.add(numerator, h_denominator)
|
r_numerator = builder.add(numerator, h_denominator)
|
||||||
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
|
||||||
|
@ -319,3 +325,8 @@ def _make_operators():
|
||||||
return SimpleNamespace(**d)
|
return SimpleNamespace(**d)
|
||||||
|
|
||||||
operators = _make_operators()
|
operators = _make_operators()
|
||||||
|
|
||||||
|
def init_module(module):
|
||||||
|
func_type = lc.Type.function(lc.Type.int(64),
|
||||||
|
[lc.Type.int(64), lc.Type.int(64)])
|
||||||
|
module.add_function(func_type, "__gcd64")
|
||||||
|
|
|
@ -40,7 +40,7 @@ def _str_to_functype(s):
|
||||||
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None)
|
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None)
|
||||||
|
|
||||||
class LinkInterface:
|
class LinkInterface:
|
||||||
def set_module(self, module):
|
def init_module(self, module):
|
||||||
self.module = module
|
self.module = module
|
||||||
self.var_arg_fixcount = dict()
|
self.var_arg_fixcount = dict()
|
||||||
for func_name, func_type_str in _syscalls.items():
|
for func_name, func_type_str in _syscalls.items():
|
||||||
|
|
Loading…
Reference in New Issue