language/units: better support for ops on different dimensions

This commit is contained in:
Sebastien Bourdeauducq 2014-10-05 23:15:15 +08:00
parent 6c44fe0a87
commit 70fc0f6ce7

View File

@ -5,12 +5,49 @@ _prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)
class DimensionError(Exception):
"""Exception raised when attempting operations on incompatible units
(e.g. adding seconds and hertz).
def mul_dimension(l, r):
if l is None:
return r
if r is None:
return l
if {l, r} == {"Hz", "s"}:
return None
"""
pass
def _rmul_dimension(l, r):
return mul_dimension(r, l)
def div_dimension(l, r):
if l == r:
return None
if r is None:
return l
if l is None:
if r == "s":
return "Hz"
if r == "Hz":
return "s"
def _rdiv_dimension(l, r):
return div_dimension(r, l)
def addsub_dimension(x, y):
if x == y:
return x
else:
return None
def _format(amount, unit):
if amount is NotImplemented:
return NotImplemented
if unit is None:
return amount
else:
return Quantity(amount, unit)
class Quantity:
@ -46,92 +83,77 @@ class Quantity:
return str(r_amount) + " " + self.unit
# mul/div
def __mul__(self, other):
def _binop(self, other, opf_name, dim_function):
opf = getattr(self.amount, opf_name)
if isinstance(other, Quantity):
return NotImplemented
return Quantity(self.amount*other, self.unit)
amount = opf(other.amount)
unit = dim_function(self.unit, other.unit)
else:
amount = opf(other)
unit = dim_function(self.unit, None)
return _format(amount, unit)
def __mul__(self, other):
return self._binop(other, "__mul__", mul_dimension)
def __rmul__(self, other):
if isinstance(other, Quantity):
return NotImplemented
return Quantity(other*self.amount, self.unit)
return self._binop(other, "__rmul__", _rmul_dimension)
def __truediv__(self, other):
if isinstance(other, Quantity):
if other.unit == self.unit:
return self.amount/other.amount
else:
return NotImplemented
else:
return Quantity(self.amount/other, self.unit)
return self._binop(other, "__truediv__", div_dimension)
def __rtruediv__(self, other):
return self._binop(other, "__rtruediv__", _rdiv_dimension)
def __floordiv__(self, other):
if isinstance(other, Quantity):
if other.unit == self.unit:
return self.amount//other.amount
else:
return NotImplemented
else:
return Quantity(self.amount//other, self.unit)
return self._binop(other, "__floordiv__", div_dimension)
def __rfloordiv__(self, other):
return self._binop(other, "__rfloordiv__", _rdiv_dimension)
# unary ops
def __neg__(self):
return Quantity(-self.amount, self.unit)
return Quantity(self.amount.__neg__(), self.unit)
def __pos__(self):
return Quantity(self.amount, self.unit)
return Quantity(self.amount.__pos__(), self.unit)
# add/sub
def __add__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(self.amount + other.amount, self.unit)
return self._binop(other, "__add__", addsub_dimension)
def __radd__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(other.amount + self.amount, self.unit)
return self._binop(other, "__radd__", addsub_dimension)
def __sub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(self.amount - other.amount, self.unit)
return self._binop(other, "__sub__", addsub_dimension)
def __rsub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(other.amount - self.amount, self.unit)
return self._binop(other, "__rsub__", addsub_dimension)
# comparisons
def _cmp(self, other, opf_name):
if isinstance(other, Quantity):
other = other.amount
return getattr(self.amount, opf_name)(other)
def __lt__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount < other.amount
return self._cmp(other, "__lt__")
def __le__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount <= other.amount
return self._cmp(other, "__le__")
def __eq__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount == other.amount
return self._cmp(other, "__eq__")
def __ne__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount != other.amount
return self._cmp(other, "__ne__")
def __gt__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount > other.amount
return self._cmp(other, "__gt__")
def __ge__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount >= other.amount
return self._cmp(other, "__ge__")
def _register_unit(unit, prefixes):