diff --git a/artiq/gateware/dsp/sawg.py b/artiq/gateware/dsp/sawg.py index 9fc064e74..5b432cd54 100644 --- a/artiq/gateware/dsp/sawg.py +++ b/artiq/gateware/dsp/sawg.py @@ -183,9 +183,9 @@ class Channel(Module, SatAddMixin): Cat(b.yi).eq(Cat(hbf[1].o)), ] self.sync += [ - hbf[0].i.eq(self.sat_add(a1.xo[0], a2.xo[0], + hbf[0].i.eq(self.sat_add((a1.xo[0], a2.xo[0]), limits=cfg.limits[1], clipped=cfg.clipped[1])), - hbf[1].i.eq(self.sat_add(a1.yo[0], a2.yo[0], + hbf[1].i.eq(self.sat_add((a1.yo[0], a2.yo[0]), limits=cfg.limits[1], clipped=cfg.clipped[1])), ] # wire up outputs and q_{i,o} exchange @@ -199,9 +199,8 @@ class Channel(Module, SatAddMixin): o_y.eq(Mux(cfg.iq_en[1], y, 0)), ] self.sync += [ - o.eq(self.sat_add(o_offset, o_x, o_y, - limits=cfg.limits[0], - clipped=cfg.clipped[0])), + o.eq(self.sat_add((o_offset, o_x, o_y), + limits=cfg.limits[0], clipped=cfg.clipped[0])), ] def connect_y(self, buddy): diff --git a/artiq/gateware/dsp/tools.py b/artiq/gateware/dsp/tools.py index 504c245c5..a9142fa76 100644 --- a/artiq/gateware/dsp/tools.py +++ b/artiq/gateware/dsp/tools.py @@ -30,34 +30,39 @@ def eqh(a, b): class SatAddMixin: """Signed saturating addition mixin""" - def sat_add(self, *a, limits=None, clipped=None): + def sat_add(self, a, *, width=None, limits=None, clipped=None): a = list(a) # assert all(value_bits_sign(ai)[1] for ai in a) - length = max(len(ai) for ai in a) + if width is None: + width = max(value_bits_sign(ai)[0] for ai in a) carry = log2_int(len(a), need_pow2=False) - full = Signal((length + carry, True)) - limited = Signal((length, True)) + full = Signal((width + carry, True)) + limited = Signal((width, True)) clip = Signal(2) + sign = Signal() if clipped is not None: self.comb += clipped.eq(clip) self.comb += [ full.eq(reduce(add, a)), + sign.eq(full[-1]), + limited.eq(full) ] if limits is None: - sign = Signal() self.comb += [ - sign.eq(full[-1]), - If(full[-1-carry:] == Replicate(sign, carry + 1), - clip.eq(0), - limited.eq(full), - ).Else( + If(full[-1-carry:] != Replicate(sign, carry + 1), clip.eq(Cat(sign, ~sign)), - limited.eq(Cat(Replicate(~sign, length - 1), sign)), + limited.eq(Cat(Replicate(~sign, width - 1), sign)), ) ] else: self.comb += [ - clip.eq(Cat(full < limits[0], full > limits[1])), - limited.eq(Array([full, limits[0], limits[1], 0])[clip]), + If(full < limits[0], + clip.eq(0b01), + limited.eq(limits[0]) + ), + If(full > limits[1], + clip.eq(0b10), + limited.eq(limits[1]), + ) ] return limited