diff --git a/artiq/compiler/ir_values.py b/artiq/compiler/ir_values.py index 72b209a37..ac34d3c3f 100644 --- a/artiq/compiler/ir_values.py +++ b/artiq/compiler/ir_values.py @@ -188,6 +188,22 @@ class VBool(VInt): # Fraction type +def _gcd64(builder, a, b): + gcd_f = builder.module.get_function_named("__gcd64") + return builder.call(gcd_f, [a, b]) + +def _frac_normalize(builder, numerator, denominator): + gcd = _gcd64(numerator, denominator) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.sdiv(denominator, gcd) + return numerator, denominator + +def _frac_make_ssa(builder, numerator, denominator): + value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) + value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) + value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) + return value + class VFraction(_Value): def get_llvm_type(self): return lc.Type.vector(lc.Type.int(64), 2) @@ -202,25 +218,20 @@ class VFraction(_Value): if not isinstance(other, VFraction): raise TypeError - def _nd(self, builder): + def _nd(self, builder, invert=False): ssa_value = self.get_ssa_value(builder) numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) - return numerator, denominator + if invert: + return denominator, numerator + else: + return numerator, denominator def set_value_nd(self, builder, numerator, denominator): numerator = numerator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder) - - gcd_f = builder.module.get_function_named("__gcd64") - gcd = builder.call(gcd_f, [numerator, denominator]) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.sdiv(denominator, gcd) - - value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) - value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) - value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) - self.set_ssa_value(builder, value) + numerator, denominator = _frac_normalize(builder, numerator, denominator) + self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) def set_value(self, builder, n): if not isinstance(n, VFraction): @@ -255,7 +266,7 @@ class VFraction(_Value): r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) return r.o_intx(target_bits, builder) - def _o_eq_inv(self, other, builder, invert): + def _o_eq_inv(self, other, builder, ne): if isinstance(other, VFraction): r = VBool() if builder is not None: @@ -265,7 +276,7 @@ class VFraction(_Value): eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) ssa_r = builder.and_(ee[0], ee[1]) - if invert: + if ne: ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) r.set_ssa_value(builder, ssa_r) return r @@ -278,6 +289,68 @@ class VFraction(_Value): def o_ne(self, other, builder): return self._o_eq_inv(other, builder, True) + def _o_muldiv(self, other, builder, div, invert=False): + r = VFraction() + if isinstance(other, VInt): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + i = other.get_ssa_value(builder) + if div: + gcd = _gcd64(i, numerator) + i = builder.sdiv(i, gcd) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.mul(denominator, i) + else: + gcd = _gcd64(i, denominator) + i = builder.sdiv(i, gcd) + denominator = builder.sdiv(denominator, gcd) + numerator = builder.mul(numerator, i) + self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) + elif isinstance(other, VFraction): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + onumerator, odenominator = other._nd(builder) + if div: + numerator = builder.mul(numerator, odenominator) + denominator = builder.mul(denominator, onumerator) + else: + numerator = builder.mul(numerator, onumerator) + denominator = builder.mul(denominator, odenominator) + numerator, denominator = _frac_normalize(builder, numerator, denominator) + self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) + else: + return NotImplemented + + def o_mul(self, other, builder): + return self._o_muldiv(other, builder, False) + + def o_truediv(self, other, builder): + return self._o_muldiv(other, builder, True) + + def or_mul(self, other, builder): + return self._o_muldiv(other, builder, False) + + def or_truediv(self, other, builder): + return self._o_muldiv(other, builder, False, True) + + def o_floordiv(self, other, builder): + r = self.o_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) + + def or_floordiv(self, other, builder): + r = self.or_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) + # Operators def _make_unary_operator(op_name):