forked from M-Labs/artiq
1
0
Fork 0

compiler/ir_values: implement rational mul/div

This commit is contained in:
Sebastien Bourdeauducq 2014-08-28 18:58:24 +08:00
parent 841e7cce35
commit 7e9df82e37
1 changed files with 87 additions and 14 deletions

View File

@ -188,6 +188,22 @@ class VBool(VInt):
# Fraction type # Fraction type
def _gcd64(builder, a, b):
gcd_f = builder.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(numerator, denominator)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(_Value): class VFraction(_Value):
def get_llvm_type(self): def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2) return lc.Type.vector(lc.Type.int(64), 2)
@ -202,25 +218,20 @@ class VFraction(_Value):
if not isinstance(other, VFraction): if not isinstance(other, VFraction):
raise TypeError raise TypeError
def _nd(self, builder): def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder) ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert:
return denominator, numerator
else:
return numerator, denominator return numerator, denominator
def set_value_nd(self, builder, numerator, denominator): def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder) numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(builder, numerator, denominator)
gcd_f = builder.module.get_function_named("__gcd64") self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator))
gcd = builder.call(gcd_f, [numerator, denominator])
numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd)
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1))
self.set_ssa_value(builder, value)
def set_value(self, builder, n): def set_value(self, builder, n):
if not isinstance(n, VFraction): if not isinstance(n, VFraction):
@ -255,7 +266,7 @@ class VFraction(_Value):
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder) return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, invert): def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction): if isinstance(other, VFraction):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
@ -265,7 +276,7 @@ class VFraction(_Value):
eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1]) ssa_r = builder.and_(ee[0], ee[1])
if invert: if ne:
ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r) r.set_ssa_value(builder, ssa_r)
return r return r
@ -278,6 +289,68 @@ class VFraction(_Value):
def o_ne(self, other, builder): def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True) return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction()
if isinstance(other, VInt):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder)
if div:
gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i)
else:
gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator))
elif isinstance(other, VFraction):
if builder is None:
return r
else:
numerator, denominator = self._nd(builder, invert)
onumerator, odenominator = other._nd(builder)
if div:
numerator = builder.mul(numerator, odenominator)
denominator = builder.mul(denominator, onumerator)
else:
numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator, denominator)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
# Operators # Operators
def _make_unary_operator(op_name): def _make_unary_operator(op_name):