diff --git a/artiq/language/units.py b/artiq/language/units.py index f79107ead..1fed268d8 100644 --- a/artiq/language/units.py +++ b/artiq/language/units.py @@ -5,12 +5,49 @@ _prefixes_str = "pnum_kMG" _smallest_prefix = _Fraction(1, 10**12) -class DimensionError(Exception): - """Exception raised when attempting operations on incompatible units - (e.g. adding seconds and hertz). +def mul_dimension(l, r): + if l is None: + return r + if r is None: + return l + if {l, r} == {"Hz", "s"}: + return None - """ - pass + +def _rmul_dimension(l, r): + return mul_dimension(r, l) + + +def div_dimension(l, r): + if l == r: + return None + if r is None: + return l + if l is None: + if r == "s": + return "Hz" + if r == "Hz": + return "s" + + +def _rdiv_dimension(l, r): + return div_dimension(r, l) + + +def addsub_dimension(x, y): + if x == y: + return x + else: + return None + + +def _format(amount, unit): + if amount is NotImplemented: + return NotImplemented + if unit is None: + return amount + else: + return Quantity(amount, unit) class Quantity: @@ -46,92 +83,77 @@ class Quantity: return str(r_amount) + " " + self.unit # mul/div - def __mul__(self, other): + def _binop(self, other, opf_name, dim_function): + opf = getattr(self.amount, opf_name) if isinstance(other, Quantity): - return NotImplemented - return Quantity(self.amount*other, self.unit) + amount = opf(other.amount) + unit = dim_function(self.unit, other.unit) + else: + amount = opf(other) + unit = dim_function(self.unit, None) + return _format(amount, unit) + + def __mul__(self, other): + return self._binop(other, "__mul__", mul_dimension) def __rmul__(self, other): - if isinstance(other, Quantity): - return NotImplemented - return Quantity(other*self.amount, self.unit) + return self._binop(other, "__rmul__", _rmul_dimension) def __truediv__(self, other): - if isinstance(other, Quantity): - if other.unit == self.unit: - return self.amount/other.amount - else: - return NotImplemented - else: - return Quantity(self.amount/other, self.unit) + return self._binop(other, "__truediv__", div_dimension) + + def __rtruediv__(self, other): + return self._binop(other, "__rtruediv__", _rdiv_dimension) def __floordiv__(self, other): - if isinstance(other, Quantity): - if other.unit == self.unit: - return self.amount//other.amount - else: - return NotImplemented - else: - return Quantity(self.amount//other, self.unit) + return self._binop(other, "__floordiv__", div_dimension) + + def __rfloordiv__(self, other): + return self._binop(other, "__rfloordiv__", _rdiv_dimension) # unary ops def __neg__(self): - return Quantity(-self.amount, self.unit) + return Quantity(self.amount.__neg__(), self.unit) def __pos__(self): - return Quantity(self.amount, self.unit) + return Quantity(self.amount.__pos__(), self.unit) # add/sub def __add__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(self.amount + other.amount, self.unit) + return self._binop(other, "__add__", addsub_dimension) def __radd__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(other.amount + self.amount, self.unit) + return self._binop(other, "__radd__", addsub_dimension) def __sub__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(self.amount - other.amount, self.unit) + return self._binop(other, "__sub__", addsub_dimension) def __rsub__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(other.amount - self.amount, self.unit) + return self._binop(other, "__rsub__", addsub_dimension) # comparisons + def _cmp(self, other, opf_name): + if isinstance(other, Quantity): + other = other.amount + return getattr(self.amount, opf_name)(other) + def __lt__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount < other.amount + return self._cmp(other, "__lt__") def __le__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount <= other.amount + return self._cmp(other, "__le__") def __eq__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount == other.amount + return self._cmp(other, "__eq__") def __ne__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount != other.amount + return self._cmp(other, "__ne__") def __gt__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount > other.amount + return self._cmp(other, "__gt__") def __ge__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount >= other.amount + return self._cmp(other, "__ge__") def _register_unit(unit, prefixes):