1
0
forked from M-Labs/artiq

fir: automatically use transposed topology

This commit is contained in:
Robert Jördens 2016-12-14 19:15:50 +01:00
parent a451b675c9
commit 115ea67860

View File

@ -1,5 +1,6 @@
from operator import add from operator import add
from functools import reduce from functools import reduce
from collections import namedtuple
import numpy as np import numpy as np
from migen import * from migen import *
@ -38,7 +39,10 @@ def halfgen4(width, n):
class FIR(Module): class FIR(Module):
"""Full-rate finite impulse response filter. """Full-rate finite impulse response filter.
:param coefficients: integer taps. Tries to use transposed form (adder chain instead of adder tree)
as much as possible.
:param coefficients: integer taps, increasing delay.
:param width: bit width of input and output. :param width: bit width of input and output.
:param shift: scale factor (as power of two). :param shift: scale factor (as power of two).
""" """
@ -47,37 +51,46 @@ class FIR(Module):
self.i = Signal((width, True)) self.i = Signal((width, True))
self.o = Signal((width, True)) self.o = Signal((width, True))
n = len(coefficients) n = len(coefficients)
self.latency = (n + 1)//2 + 2 self.latency = n//2 + 3
### ###
if shift is None:
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
# Delay line: increasing delay # Delay line: increasing delay
x = [Signal((width, True)) for _ in range(n)] x = [Signal((width, True)) for _ in range(n)]
self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)] self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)]
if shift is None: o = Signal((width + shift + 1, True))
shift = width - 1 self.comb += self.o.eq(o >> shift)
delay = -1
# Make products # Make products
o = []
for i, c in enumerate(coefficients): for i, c in enumerate(coefficients):
# simplify for halfband and symmetric filters # simplify for halfband and symmetric filters
if c == 0 or c in coefficients[i + 1:]: if not c or c in coefficients[:i]:
continue continue
m = Signal((width + shift, True)) js = [j for j, cj in enumerate(coefficients) if cj == c]
self.sync += m.eq(c*reduce(add, [ m = Signal.like(o)
xj for xj, cj in zip(x[::-1], coefficients) if cj == c o0, o = o, Signal.like(o)
])) if delay < js[0]:
o.append(m) self.sync += o0.eq(o + m)
delay += 1
# Make sum else:
self.sync += self.o.eq(reduce(add, o) >> shift) self.comb += o0.eq(o + m)
assert js[0] - delay >= 0
self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js]))
# symmetric rounding
if shift:
self.comb += o.eq((1 << shift - 1) - 1)
class ParallelFIR(Module): class ParallelFIR(Module):
"""Full-rate parallelized finite impulse response filter. """Full-rate parallelized finite impulse response filter.
:param coefficients: integer taps. Tries to use transposed form as much as possible.
:param coefficients: integer taps, increasing delay.
:param parallelism: number of samples per cycle. :param parallelism: number of samples per cycle.
:param width: bit width of input and output. :param width: bit width of input and output.
:param shift: scale factor (as power of two). :param shift: scale factor (as power of two).
@ -86,34 +99,43 @@ class ParallelFIR(Module):
self.width = width self.width = width
self.parallelism = p = parallelism self.parallelism = p = parallelism
n = len(coefficients) n = len(coefficients)
# input and output: old to young, decreasing delay # input and output: old to new, decreasing delay
self.i = [Signal((width, True)) for i in range(p)] self.i = [Signal((width, True)) for i in range(p)]
self.o = [Signal((width, True)) for i in range(p)] self.o = [Signal((width, True)) for i in range(p)]
self.latency = (n + 1)//2//parallelism + 3 # minus one sample self.latency = (n + 1)//2//p + 2
# ... plus one sample
### ###
# Delay line: young to old, increasing delay if shift is None:
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
# Delay line: increasing delay
x = [Signal((width, True)) for _ in range(n + p - 1)] x = [Signal((width, True)) for _ in range(n + p - 1)]
self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)] self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)]
if shift is None: for delay in range(p):
shift = width - 1 o = Signal((width + shift + 1, True))
self.comb += self.o[delay].eq(o >> shift)
for j in range(p):
# Make products # Make products
o = []
for i, c in enumerate(coefficients): for i, c in enumerate(coefficients):
# simplify for halfband and symmetric filters # simplify for halfband and symmetric filters
if c == 0 or c in coefficients[i + 1:]: if not c or c in coefficients[:i]:
continue continue
m = Signal((width + shift, True)) js = [j + p - 1 for j, cj in enumerate(coefficients)
self.sync += m.eq(c*reduce(add, [ if cj == c]
xj for xj, cj in zip(x[-1 - j::-1], coefficients) if cj == c m = Signal.like(o)
])) o0, o = o, Signal.like(o)
o.append(m) if delay + p <= js[0]:
# Make sum self.sync += o0.eq(o + m)
self.sync += self.o[j].eq(reduce(add, o) >> shift) delay += p
else:
self.comb += o0.eq(o + m)
assert js[0] - delay >= 0
self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js]))
# symmetric rounding
if shift:
self.comb += o.eq((1 << shift - 1) - 1)
def halfgen4_cascade(rate, width, order=None): def halfgen4_cascade(rate, width, order=None):