From 115ea678606809e246f1bd1956d061b7ed0ad942 Mon Sep 17 00:00:00 2001 From: Robert Jordens Date: Wed, 14 Dec 2016 19:15:50 +0100 Subject: [PATCH] fir: automatically use transposed topology --- artiq/gateware/dsp/fir.py | 86 ++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 32 deletions(-) diff --git a/artiq/gateware/dsp/fir.py b/artiq/gateware/dsp/fir.py index 6f1535528..06bf2c6cd 100644 --- a/artiq/gateware/dsp/fir.py +++ b/artiq/gateware/dsp/fir.py @@ -1,5 +1,6 @@ from operator import add from functools import reduce +from collections import namedtuple import numpy as np from migen import * @@ -38,7 +39,10 @@ def halfgen4(width, n): class FIR(Module): """Full-rate finite impulse response filter. - :param coefficients: integer taps. + Tries to use transposed form (adder chain instead of adder tree) + as much as possible. + + :param coefficients: integer taps, increasing delay. :param width: bit width of input and output. :param shift: scale factor (as power of two). """ @@ -47,37 +51,46 @@ class FIR(Module): self.i = Signal((width, True)) self.o = Signal((width, True)) n = len(coefficients) - self.latency = (n + 1)//2 + 2 + self.latency = n//2 + 3 ### + if shift is None: + shift = bits_for(sum(abs(c) for c in coefficients)) - 1 + # Delay line: increasing delay x = [Signal((width, True)) for _ in range(n)] self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)] - if shift is None: - shift = width - 1 - + o = Signal((width + shift + 1, True)) + self.comb += self.o.eq(o >> shift) + delay = -1 # Make products - o = [] for i, c in enumerate(coefficients): # simplify for halfband and symmetric filters - if c == 0 or c in coefficients[i + 1:]: + if not c or c in coefficients[:i]: continue - m = Signal((width + shift, True)) - self.sync += m.eq(c*reduce(add, [ - xj for xj, cj in zip(x[::-1], coefficients) if cj == c - ])) - o.append(m) - - # Make sum - self.sync += self.o.eq(reduce(add, o) >> shift) + js = [j for j, cj in enumerate(coefficients) if cj == c] + m = Signal.like(o) + o0, o = o, Signal.like(o) + if delay < js[0]: + self.sync += o0.eq(o + m) + delay += 1 + else: + self.comb += o0.eq(o + m) + assert js[0] - delay >= 0 + self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js])) + # symmetric rounding + if shift: + self.comb += o.eq((1 << shift - 1) - 1) class ParallelFIR(Module): """Full-rate parallelized finite impulse response filter. - :param coefficients: integer taps. + Tries to use transposed form as much as possible. + + :param coefficients: integer taps, increasing delay. :param parallelism: number of samples per cycle. :param width: bit width of input and output. :param shift: scale factor (as power of two). @@ -86,34 +99,43 @@ class ParallelFIR(Module): self.width = width self.parallelism = p = parallelism n = len(coefficients) - # input and output: old to young, decreasing delay + # input and output: old to new, decreasing delay self.i = [Signal((width, True)) for i in range(p)] self.o = [Signal((width, True)) for i in range(p)] - self.latency = (n + 1)//2//parallelism + 3 # minus one sample + self.latency = (n + 1)//2//p + 2 + # ... plus one sample ### - # Delay line: young to old, increasing delay + if shift is None: + shift = bits_for(sum(abs(c) for c in coefficients)) - 1 + + # Delay line: increasing delay x = [Signal((width, True)) for _ in range(n + p - 1)] self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)] - if shift is None: - shift = width - 1 - - for j in range(p): + for delay in range(p): + o = Signal((width + shift + 1, True)) + self.comb += self.o[delay].eq(o >> shift) # Make products - o = [] for i, c in enumerate(coefficients): # simplify for halfband and symmetric filters - if c == 0 or c in coefficients[i + 1:]: + if not c or c in coefficients[:i]: continue - m = Signal((width + shift, True)) - self.sync += m.eq(c*reduce(add, [ - xj for xj, cj in zip(x[-1 - j::-1], coefficients) if cj == c - ])) - o.append(m) - # Make sum - self.sync += self.o[j].eq(reduce(add, o) >> shift) + js = [j + p - 1 for j, cj in enumerate(coefficients) + if cj == c] + m = Signal.like(o) + o0, o = o, Signal.like(o) + if delay + p <= js[0]: + self.sync += o0.eq(o + m) + delay += p + else: + self.comb += o0.eq(o + m) + assert js[0] - delay >= 0 + self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js])) + # symmetric rounding + if shift: + self.comb += o.eq((1 << shift - 1) - 1) def halfgen4_cascade(rate, width, order=None):