From a8612264090387b9150f4fe714c59967caced395 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Thu, 28 Aug 2014 17:24:33 +0800 Subject: [PATCH] compiler/ir: fraction normalization --- artiq/compiler/ir.py | 5 +++-- artiq/compiler/ir_values.py | 25 ++++++++++++++++++------- artiq/devices/runtime.py | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index c2d69740c..23a7b2cba 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -1,7 +1,7 @@ from llvm import core as lc from llvm import passes as lp -from artiq.compiler import ir_infer_types, ir_ast_body +from artiq.compiler import ir_infer_types, ir_ast_body, ir_values def compile_function(module, env, funcdef): function_type = lc.Type.function(lc.Type.void(), []) @@ -18,7 +18,8 @@ def compile_function(module, env, funcdef): def get_runtime_binary(env, funcdef): module = lc.Module.new("main") - env.set_module(module) + env.init_module(module) + ir_values.init_module(module) compile_function(module, env, funcdef) diff --git a/artiq/compiler/ir_values.py b/artiq/compiler/ir_values.py index b5e9ed85d..aa7737422 100644 --- a/artiq/compiler/ir_values.py +++ b/artiq/compiler/ir_values.py @@ -202,17 +202,25 @@ class VFraction(_Value): if not isinstance(other, VFraction): raise TypeError - def _simplify(self, builder): - pass # TODO + def _nd(self, builder): + ssa_value = self.get_ssa_value(builder) + numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) + denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) + return numerator, denominator def set_value_nd(self, builder, numerator, denominator): numerator = numerator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder) + + gcd_f = builder.module.get_function_named("__gcd64") + gcd = builder.call(gcd_f, [numerator, denominator]) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.sdiv(denominator, gcd) + value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) self.set_ssa_value(builder, value) - self._simplify(builder) def set_value(self, builder, n): if not isinstance(n, VFraction): @@ -232,8 +240,7 @@ class VFraction(_Value): return VInt(target_bits) else: r = VInt(64) - numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1)) + numerator, denominator = self._nd(builder) r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) return r.o_intx(target_bits, builder) @@ -242,8 +249,7 @@ class VFraction(_Value): return VInt(target_bits) else: r = VInt(64) - numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - denominator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 1)) + numerator, denominator = self._nd(builder) h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1)) r_numerator = builder.add(numerator, h_denominator) r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) @@ -319,3 +325,8 @@ def _make_operators(): return SimpleNamespace(**d) operators = _make_operators() + +def init_module(module): + func_type = lc.Type.function(lc.Type.int(64), + [lc.Type.int(64), lc.Type.int(64)]) + module.add_function(func_type, "__gcd64") diff --git a/artiq/devices/runtime.py b/artiq/devices/runtime.py index da690e3ff..a134aadd7 100644 --- a/artiq/devices/runtime.py +++ b/artiq/devices/runtime.py @@ -40,7 +40,7 @@ def _str_to_functype(s): return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None) class LinkInterface: - def set_module(self, module): + def init_module(self, module): self.module = module self.var_arg_fixcount = dict() for func_name, func_type_str in _syscalls.items():