forked from M-Labs/artiq
fir: automatically use transposed topology
This commit is contained in:
parent
a451b675c9
commit
115ea67860
|
@ -1,5 +1,6 @@
|
||||||
from operator import add
|
from operator import add
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
from collections import namedtuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from migen import *
|
from migen import *
|
||||||
|
|
||||||
|
@ -38,7 +39,10 @@ def halfgen4(width, n):
|
||||||
class FIR(Module):
|
class FIR(Module):
|
||||||
"""Full-rate finite impulse response filter.
|
"""Full-rate finite impulse response filter.
|
||||||
|
|
||||||
:param coefficients: integer taps.
|
Tries to use transposed form (adder chain instead of adder tree)
|
||||||
|
as much as possible.
|
||||||
|
|
||||||
|
:param coefficients: integer taps, increasing delay.
|
||||||
:param width: bit width of input and output.
|
:param width: bit width of input and output.
|
||||||
:param shift: scale factor (as power of two).
|
:param shift: scale factor (as power of two).
|
||||||
"""
|
"""
|
||||||
|
@ -47,37 +51,46 @@ class FIR(Module):
|
||||||
self.i = Signal((width, True))
|
self.i = Signal((width, True))
|
||||||
self.o = Signal((width, True))
|
self.o = Signal((width, True))
|
||||||
n = len(coefficients)
|
n = len(coefficients)
|
||||||
self.latency = (n + 1)//2 + 2
|
self.latency = n//2 + 3
|
||||||
|
|
||||||
###
|
###
|
||||||
|
|
||||||
|
if shift is None:
|
||||||
|
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
|
||||||
|
|
||||||
# Delay line: increasing delay
|
# Delay line: increasing delay
|
||||||
x = [Signal((width, True)) for _ in range(n)]
|
x = [Signal((width, True)) for _ in range(n)]
|
||||||
self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)]
|
self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)]
|
||||||
|
|
||||||
if shift is None:
|
o = Signal((width + shift + 1, True))
|
||||||
shift = width - 1
|
self.comb += self.o.eq(o >> shift)
|
||||||
|
delay = -1
|
||||||
# Make products
|
# Make products
|
||||||
o = []
|
|
||||||
for i, c in enumerate(coefficients):
|
for i, c in enumerate(coefficients):
|
||||||
# simplify for halfband and symmetric filters
|
# simplify for halfband and symmetric filters
|
||||||
if c == 0 or c in coefficients[i + 1:]:
|
if not c or c in coefficients[:i]:
|
||||||
continue
|
continue
|
||||||
m = Signal((width + shift, True))
|
js = [j for j, cj in enumerate(coefficients) if cj == c]
|
||||||
self.sync += m.eq(c*reduce(add, [
|
m = Signal.like(o)
|
||||||
xj for xj, cj in zip(x[::-1], coefficients) if cj == c
|
o0, o = o, Signal.like(o)
|
||||||
]))
|
if delay < js[0]:
|
||||||
o.append(m)
|
self.sync += o0.eq(o + m)
|
||||||
|
delay += 1
|
||||||
# Make sum
|
else:
|
||||||
self.sync += self.o.eq(reduce(add, o) >> shift)
|
self.comb += o0.eq(o + m)
|
||||||
|
assert js[0] - delay >= 0
|
||||||
|
self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js]))
|
||||||
|
# symmetric rounding
|
||||||
|
if shift:
|
||||||
|
self.comb += o.eq((1 << shift - 1) - 1)
|
||||||
|
|
||||||
|
|
||||||
class ParallelFIR(Module):
|
class ParallelFIR(Module):
|
||||||
"""Full-rate parallelized finite impulse response filter.
|
"""Full-rate parallelized finite impulse response filter.
|
||||||
|
|
||||||
:param coefficients: integer taps.
|
Tries to use transposed form as much as possible.
|
||||||
|
|
||||||
|
:param coefficients: integer taps, increasing delay.
|
||||||
:param parallelism: number of samples per cycle.
|
:param parallelism: number of samples per cycle.
|
||||||
:param width: bit width of input and output.
|
:param width: bit width of input and output.
|
||||||
:param shift: scale factor (as power of two).
|
:param shift: scale factor (as power of two).
|
||||||
|
@ -86,34 +99,43 @@ class ParallelFIR(Module):
|
||||||
self.width = width
|
self.width = width
|
||||||
self.parallelism = p = parallelism
|
self.parallelism = p = parallelism
|
||||||
n = len(coefficients)
|
n = len(coefficients)
|
||||||
# input and output: old to young, decreasing delay
|
# input and output: old to new, decreasing delay
|
||||||
self.i = [Signal((width, True)) for i in range(p)]
|
self.i = [Signal((width, True)) for i in range(p)]
|
||||||
self.o = [Signal((width, True)) for i in range(p)]
|
self.o = [Signal((width, True)) for i in range(p)]
|
||||||
self.latency = (n + 1)//2//parallelism + 3 # minus one sample
|
self.latency = (n + 1)//2//p + 2
|
||||||
|
# ... plus one sample
|
||||||
|
|
||||||
###
|
###
|
||||||
|
|
||||||
# Delay line: young to old, increasing delay
|
if shift is None:
|
||||||
|
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
|
||||||
|
|
||||||
|
# Delay line: increasing delay
|
||||||
x = [Signal((width, True)) for _ in range(n + p - 1)]
|
x = [Signal((width, True)) for _ in range(n + p - 1)]
|
||||||
self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)]
|
self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)]
|
||||||
|
|
||||||
if shift is None:
|
for delay in range(p):
|
||||||
shift = width - 1
|
o = Signal((width + shift + 1, True))
|
||||||
|
self.comb += self.o[delay].eq(o >> shift)
|
||||||
for j in range(p):
|
|
||||||
# Make products
|
# Make products
|
||||||
o = []
|
|
||||||
for i, c in enumerate(coefficients):
|
for i, c in enumerate(coefficients):
|
||||||
# simplify for halfband and symmetric filters
|
# simplify for halfband and symmetric filters
|
||||||
if c == 0 or c in coefficients[i + 1:]:
|
if not c or c in coefficients[:i]:
|
||||||
continue
|
continue
|
||||||
m = Signal((width + shift, True))
|
js = [j + p - 1 for j, cj in enumerate(coefficients)
|
||||||
self.sync += m.eq(c*reduce(add, [
|
if cj == c]
|
||||||
xj for xj, cj in zip(x[-1 - j::-1], coefficients) if cj == c
|
m = Signal.like(o)
|
||||||
]))
|
o0, o = o, Signal.like(o)
|
||||||
o.append(m)
|
if delay + p <= js[0]:
|
||||||
# Make sum
|
self.sync += o0.eq(o + m)
|
||||||
self.sync += self.o[j].eq(reduce(add, o) >> shift)
|
delay += p
|
||||||
|
else:
|
||||||
|
self.comb += o0.eq(o + m)
|
||||||
|
assert js[0] - delay >= 0
|
||||||
|
self.sync += m.eq(c*reduce(add, [x[j - delay] for j in js]))
|
||||||
|
# symmetric rounding
|
||||||
|
if shift:
|
||||||
|
self.comb += o.eq((1 << shift - 1) - 1)
|
||||||
|
|
||||||
|
|
||||||
def halfgen4_cascade(rate, width, order=None):
|
def halfgen4_cascade(rate, width, order=None):
|
||||||
|
|
Loading…
Reference in New Issue