forked from M-Labs/artiq
fir: streamline, optimize DSP extraction, left-align inputs
This commit is contained in:
parent
cfb66117af
commit
f5f662200b
@ -1,6 +1,10 @@
|
||||
from math import floor
|
||||
from operator import add
|
||||
from functools import reduce
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from migen import *
|
||||
|
||||
|
||||
@ -40,56 +44,11 @@ def halfgen4(width, n, df=1e-3):
|
||||
return a
|
||||
|
||||
|
||||
class FIR(Module):
|
||||
"""Full-rate finite impulse response filter.
|
||||
_Widths = namedtuple("_Widths", "A B P")
|
||||
|
||||
Tries to use transposed form (adder chain instead of adder tree)
|
||||
as much as possible.
|
||||
|
||||
:param coefficients: integer taps, increasing delay.
|
||||
:param width: bit width of input and output.
|
||||
:param shift: scale factor (as power of two).
|
||||
"""
|
||||
def __init__(self, coefficients, width=16, shift=None):
|
||||
self.width = width
|
||||
self.i = Signal((width, True))
|
||||
self.o = Signal((width, True))
|
||||
n = len(coefficients)
|
||||
self.latency = n//2 + 3
|
||||
|
||||
###
|
||||
|
||||
if shift is None:
|
||||
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
|
||||
|
||||
# Delay line: increasing delay
|
||||
x = [Signal((width, True)) for _ in range(n)]
|
||||
self.sync += [xi.eq(xj) for xi, xj in zip(x, [self.i] + x)]
|
||||
|
||||
o = Signal((width + shift + 1, True))
|
||||
self.comb += self.o.eq(o >> shift)
|
||||
delay = -1
|
||||
# Make products
|
||||
for i, c in enumerate(coefficients):
|
||||
# simplify for halfband and symmetric filters
|
||||
if not c or c in coefficients[:i]:
|
||||
continue
|
||||
js = [j for j, cj in enumerate(coefficients) if cj == c]
|
||||
m = Signal.like(o)
|
||||
o0, o = o, Signal.like(o)
|
||||
if delay < js[0]:
|
||||
self.sync += o0.eq(o + m)
|
||||
delay += 1
|
||||
else:
|
||||
self.comb += o0.eq(o + m)
|
||||
assert js[0] - delay >= 0
|
||||
xs = [x[j - delay] for j in js]
|
||||
s = Signal((bits_for(len(xs)) - 1 + len(xs[0]), True))
|
||||
self.comb += s.eq(sum(xs))
|
||||
self.sync += m.eq(c*s)
|
||||
# symmetric rounding
|
||||
if shift:
|
||||
self.comb += o.eq((1 << shift - 1) - 1)
|
||||
_widths = {
|
||||
"DSP48E1": _Widths(25, 18, 48),
|
||||
}
|
||||
|
||||
|
||||
class ParallelFIR(Module):
|
||||
@ -97,12 +56,14 @@ class ParallelFIR(Module):
|
||||
|
||||
Tries to use transposed form as much as possible.
|
||||
|
||||
:param coefficients: integer taps, increasing delay.
|
||||
:param coefficients: tap coefficients (normalized to 1.),
|
||||
increasing delay.
|
||||
:param parallelism: number of samples per cycle.
|
||||
:param width: bit width of input and output.
|
||||
:param shift: scale factor (as power of two).
|
||||
:param arch: architecture (default: "DSP48E1").
|
||||
"""
|
||||
def __init__(self, coefficients, parallelism, width=16, shift=None):
|
||||
def __init__(self, coefficients, parallelism, width=16,
|
||||
arch="DSP48E1"):
|
||||
self.width = width
|
||||
self.parallelism = p = parallelism
|
||||
n = len(coefficients)
|
||||
@ -111,45 +72,60 @@ class ParallelFIR(Module):
|
||||
self.o = [Signal((width, True)) for i in range(p)]
|
||||
self.latency = (n + 1)//2//p + 2
|
||||
# ... plus one sample
|
||||
w = _widths[arch]
|
||||
|
||||
c_max = max(abs(c) for c in coefficients)
|
||||
c_shift = bits_for(floor((1 << w.B - 2) / c_max))
|
||||
self.coefficients = cs = [int(round(c*(1 << c_shift)))
|
||||
for c in coefficients]
|
||||
|
||||
###
|
||||
|
||||
if shift is None:
|
||||
shift = bits_for(sum(abs(c) for c in coefficients)) - 1
|
||||
|
||||
# Delay line: increasing delay
|
||||
x = [Signal((width, True)) for _ in range(n + p - 1)]
|
||||
self.sync += [xi.eq(xj) for xi, xj in zip(x, self.i[::-1] + x)]
|
||||
x = [Signal((w.A, True)) for _ in range(n + p - 1)]
|
||||
x_shift = w.A - width - bits_for(
|
||||
max(cs.count(c) for c in cs if c) - 1)
|
||||
for xi, xj in zip(x, self.i[::-1]):
|
||||
self.sync += xi.eq(xj << x_shift)
|
||||
for xi, xj in zip(x[len(self.i):], x):
|
||||
self.sync += xi.eq(xj)
|
||||
|
||||
for delay in range(p):
|
||||
o = Signal((width + shift + 1, True))
|
||||
self.comb += self.o[delay].eq(o >> shift)
|
||||
o = Signal((w.P, True))
|
||||
self.comb += self.o[delay].eq(o >> c_shift + x_shift)
|
||||
# Make products
|
||||
for i, c in enumerate(coefficients):
|
||||
for i, c in enumerate(cs):
|
||||
# simplify for halfband and symmetric filters
|
||||
if not c or c in coefficients[:i]:
|
||||
if not c or c in cs[:i]:
|
||||
continue
|
||||
js = [j + p - 1 for j, cj in enumerate(coefficients)
|
||||
if cj == c]
|
||||
js = [j + p - 1 for j, cj in enumerate(cs) if cj == c]
|
||||
m = Signal.like(o)
|
||||
o0, o = o, Signal.like(o)
|
||||
q = Signal.like(x[0])
|
||||
if delay + p <= js[0]:
|
||||
self.sync += o0.eq(o + m)
|
||||
delay += p
|
||||
else:
|
||||
self.comb += o0.eq(o + m)
|
||||
assert js[0] - delay >= 0
|
||||
xs = [x[j - delay] for j in js]
|
||||
s = Signal((bits_for(len(xs)) - 1 + len(xs[0]), True))
|
||||
self.comb += s.eq(sum(xs))
|
||||
self.sync += m.eq(c*s)
|
||||
self.comb += q.eq(reduce(add, [x[j - delay] for j in js]))
|
||||
self.sync += m.eq(c*q)
|
||||
# symmetric rounding
|
||||
if shift:
|
||||
self.comb += o.eq((1 << shift - 1) - 1)
|
||||
if c_shift + x_shift > 1:
|
||||
self.comb += o.eq((1 << c_shift + x_shift - 1) - 1)
|
||||
|
||||
|
||||
class FIR(ParallelFIR):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(self, *args, parallelism=1, **kwargs)
|
||||
self.i = self.i[0]
|
||||
self.o = self.o[0]
|
||||
|
||||
|
||||
def halfgen4_cascade(rate, width, order=None):
|
||||
"""Generate coefficients for cascaded half-band filters.
|
||||
Coefficients are normalized to a gain of two per stage to compensate for
|
||||
the zero stuffing.
|
||||
|
||||
:param rate: upsampling rate. power of two
|
||||
:param width: passband/stopband width in units of input sampling rate.
|
||||
@ -160,7 +136,7 @@ def halfgen4_cascade(rate, width, order=None):
|
||||
p = 1
|
||||
while p < rate:
|
||||
p *= 2
|
||||
coeff.append(halfgen4(width*p/rate/2, order*p//rate))
|
||||
coeff.append(2*halfgen4(width*p/rate/2, order*p//rate))
|
||||
return coeff
|
||||
|
||||
|
||||
@ -170,8 +146,8 @@ class ParallelHBFUpsampler(Module):
|
||||
Coefficients should be normalized to overall gain of 2
|
||||
(highest/center coefficient being 1)."""
|
||||
def __init__(self, coefficients, width=16, **kwargs):
|
||||
self.parallelism = 1
|
||||
self.latency = 0
|
||||
self.parallelism = 1 # accumulate
|
||||
self.latency = 0 # accumulate
|
||||
self.width = width
|
||||
self.i = Signal((width, True))
|
||||
|
||||
@ -180,7 +156,6 @@ class ParallelHBFUpsampler(Module):
|
||||
i = [self.i]
|
||||
for coeff in coefficients:
|
||||
self.parallelism *= 2
|
||||
# assert coeff[len(coeff)//2 + 1] == 1
|
||||
hbf = ParallelFIR(coeff, self.parallelism, width, **kwargs)
|
||||
self.submodules += hbf
|
||||
self.comb += [a.eq(b) for a, b in zip(hbf.i[::2], i)]
|
||||
|
@ -128,10 +128,8 @@ class Channel(Module, SatAddMixin):
|
||||
|
||||
self.submodules.a1 = a1 = SplineParallelDDS(widths, orders)
|
||||
self.submodules.a2 = a2 = SplineParallelDDS(widths, orders)
|
||||
coeff = [[int(round((1 << 18)*ci)) for ci in c]
|
||||
for c in halfgen4_cascade(parallelism, width=.4, order=8)]
|
||||
hbf = [ParallelHBFUpsampler(coeff, width=width, shift=17)
|
||||
for i in range(2)]
|
||||
coeff = halfgen4_cascade(parallelism, width=.4, order=8)
|
||||
hbf = [ParallelHBFUpsampler(coeff, width=width) for i in range(2)]
|
||||
self.submodules.b = b = SplineParallelDUC(
|
||||
widths._replace(a=len(hbf[0].o[0]), f=widths.f - width), orders,
|
||||
parallelism=parallelism)
|
||||
|
@ -11,16 +11,16 @@ class Transfer(Module):
|
||||
self.submodules.dut = dut
|
||||
|
||||
def drive(self, x):
|
||||
for xi in x:
|
||||
yield self.dut.i.eq(int(xi))
|
||||
for xi in x.reshape(-1, self.dut.parallelism):
|
||||
yield [ij.eq(int(xj)) for ij, xj in zip(self.dut.i, xi)]
|
||||
yield
|
||||
|
||||
def record(self, y):
|
||||
for i in range(self.dut.latency):
|
||||
yield
|
||||
for i in range(len(y)):
|
||||
for yi in y.reshape(-1, self.dut.parallelism):
|
||||
yield
|
||||
y[i] = (yield self.dut.o)
|
||||
yi[:] = (yield from [(yield o) for o in self.dut.o])
|
||||
|
||||
def run(self, samples, amplitude=1.):
|
||||
w = 2**(self.dut.width - 1) - 1
|
||||
@ -63,21 +63,7 @@ class Transfer(Module):
|
||||
return fig
|
||||
|
||||
|
||||
class ParallelTransfer(Transfer):
|
||||
def drive(self, x):
|
||||
for xi in x.reshape(-1, self.dut.parallelism):
|
||||
yield [ij.eq(int(xj)) for ij, xj in zip(self.dut.i, xi)]
|
||||
yield
|
||||
|
||||
def record(self, y):
|
||||
for i in range(self.dut.latency):
|
||||
yield
|
||||
for yi in y.reshape(-1, self.dut.parallelism):
|
||||
yield
|
||||
yi[:] = (yield from [(yield o) for o in self.dut.o])
|
||||
|
||||
|
||||
class UpTransfer(ParallelTransfer):
|
||||
class UpTransfer(Transfer):
|
||||
def drive(self, x):
|
||||
x = x.reshape(-1, len(self.dut.o))
|
||||
x[:, 1:] = 0
|
||||
@ -94,21 +80,15 @@ class UpTransfer(ParallelTransfer):
|
||||
|
||||
|
||||
def _main():
|
||||
coeff = fir.halfgen4(.4/2, 8)
|
||||
coeff_int = [int(round(c * (1 << 16 - 1))) for c in coeff]
|
||||
if False:
|
||||
coeff = [[int(round((1 << 19) * ci)) for ci in c]
|
||||
for c in fir.halfgen4_cascade(8, width=.4, order=8)]
|
||||
dut = fir.ParallelHBFUpsampler(coeff, width=16, shift=18)
|
||||
if True:
|
||||
coeff = fir.halfgen4_cascade(8, width=.4, order=8)
|
||||
dut = fir.ParallelHBFUpsampler(coeff, width=16)
|
||||
# print(verilog.convert(dut, ios=set([dut.i] + dut.o)))
|
||||
tb = UpTransfer(dut)
|
||||
elif True:
|
||||
dut = fir.ParallelFIR(coeff_int, parallelism=4, width=16)
|
||||
# print(verilog.convert(dut, ios=set(dut.i + dut.o)))
|
||||
tb = ParallelTransfer(dut)
|
||||
else:
|
||||
dut = fir.FIR(coeff_int, width=16)
|
||||
# print(verilog.convert(dut, ios={dut.i, dut.o}))
|
||||
coeff = fir.halfgen4(.4/2, 8)
|
||||
dut = fir.ParallelFIR(coeff, parallelism=4, width=16)
|
||||
# print(verilog.convert(dut, ios=set(dut.i + dut.o)))
|
||||
tb = Transfer(dut)
|
||||
|
||||
x, y = tb.run(samples=1 << 10, amplitude=.5)
|
||||
|
Loading…
Reference in New Issue
Block a user