1
0
forked from M-Labs/artiq

gateware: fix drtio/dma tests

This commit is contained in:
occheung 2021-11-08 13:05:55 +08:00 committed by Sébastien Bourdeauducq
parent 02119282b8
commit 09945ecc4d
2 changed files with 85 additions and 60 deletions

View File

@ -33,63 +33,67 @@ class Loopback(Module):
class TB(Module): class TB(Module):
def __init__(self, nwords): def __init__(self, nwords, dw):
self.submodules.link_layer = Loopback(nwords) self.submodules.link_layer = Loopback(nwords)
self.submodules.aux_controller = ClockDomainsRenamer( self.submodules.aux_controller = ClockDomainsRenamer(
{"rtio": "sys", "rtio_rx": "sys"})(DRTIOAuxController(self.link_layer)) {"rtio": "sys", "rtio_rx": "sys"})(DRTIOAuxController(self.link_layer, dw))
class TestAuxController(unittest.TestCase): class TestAuxController(unittest.TestCase):
def test_aux_controller(self): def test_aux_controller(self):
dut = TB(4) dut = {
32: TB(4, 32),
64: TB(4, 64)
}
def link_init(): def link_init(dw):
for i in range(8): for i in range(8):
yield yield
yield dut.link_layer.ready.eq(1) yield dut[dw].link_layer.ready.eq(1)
def send_packet(packet): def send_packet(packet, dw):
for i, d in enumerate(packet): for i, d in enumerate(packet):
yield from dut.aux_controller.bus.write(i, d) yield from dut[dw].aux_controller.bus.write(i, d)
yield from dut.aux_controller.transmitter.aux_tx_length.write(len(packet)*4) yield from dut[dw].aux_controller.transmitter.aux_tx_length.write(len(packet)*dw//8)
yield from dut.aux_controller.transmitter.aux_tx.write(1) yield from dut[dw].aux_controller.transmitter.aux_tx.write(1)
yield yield
while (yield from dut.aux_controller.transmitter.aux_tx.read()): while (yield from dut[dw].aux_controller.transmitter.aux_tx.read()):
yield yield
def receive_packet(): def receive_packet(dw):
while not (yield from dut.aux_controller.receiver.aux_rx_present.read()): while not (yield from dut[dw].aux_controller.receiver.aux_rx_present.read()):
yield yield
length = yield from dut.aux_controller.receiver.aux_rx_length.read() length = yield from dut[dw].aux_controller.receiver.aux_rx_length.read()
r = [] r = []
for i in range(length//4): for i in range(length//(dw//8)):
r.append((yield from dut.aux_controller.bus.read(256+i))) r.append((yield from dut[dw].aux_controller.bus.read(256+i)))
yield from dut.aux_controller.receiver.aux_rx_present.write(1) yield from dut[dw].aux_controller.receiver.aux_rx_present.write(1)
return r return r
prng = random.Random(0) prng = random.Random(0)
def send_and_check_packet(): def send_and_check_packet(dw):
data = [prng.randrange(2**32-1) for _ in range(prng.randrange(1, 16))] data = [prng.randrange(2**dw-1) for _ in range(prng.randrange(1, 16))]
yield from send_packet(data) yield from send_packet(data, dw)
received = yield from receive_packet() received = yield from receive_packet(dw)
self.assertEqual(data, received) self.assertEqual(data, received)
def sim(): def sim(dw):
yield from link_init() yield from link_init(dw)
for i in range(8): for i in range(8):
yield from send_and_check_packet() yield from send_and_check_packet(dw)
@passive @passive
def rt_traffic(): def rt_traffic(dw):
while True: while True:
while prng.randrange(4): while prng.randrange(4):
yield yield
yield dut.link_layer.tx_rt_frame.eq(1) yield dut[dw].link_layer.tx_rt_frame.eq(1)
yield yield
while prng.randrange(4): while prng.randrange(4):
yield yield
yield dut.link_layer.tx_rt_frame.eq(0) yield dut[dw].link_layer.tx_rt_frame.eq(0)
yield yield
run_simulation(dut, [sim(), rt_traffic()]) run_simulation(dut[32], [sim(32), rt_traffic(32)])
run_simulation(dut[64], [sim(64), rt_traffic(64)])

View File

@ -31,24 +31,25 @@ def encode_record(channel, timestamp, address, data):
return encode_n(len(r)+1, 1, 1) + r return encode_n(len(r)+1, 1, 1) + r
def pack(x, size): def pack(x, size, dw):
r = [] r = []
for i in range((len(x)+size-1)//size): for i in range((len(x)+size-1)//size):
n = 0 n = 0
for j in range(i*size, (i+1)*size): for j in range(i*size//(dw//8), (i+1)*size//(dw//8)):
n <<= 8 n <<= dw
try: try:
n |= x[j] encoded = int.from_bytes(x[j*(dw//8): (j+1)*(dw//8)], "little")
n |= encoded
except IndexError: except IndexError:
pass pass
r.append(n) r.append(n)
return r return r
def encode_sequence(writes, ws): def encode_sequence(writes, ws, dw):
sequence = [b for write in writes for b in encode_record(*write)] sequence = [b for write in writes for b in encode_record(*write)]
sequence.append(0) sequence.append(0)
return pack(sequence, ws) return pack(sequence, ws, dw)
def do_dma(dut, address): def do_dma(dut, address):
@ -84,9 +85,9 @@ prng = random.Random(0)
class TB(Module): class TB(Module):
def __init__(self, ws): def __init__(self, ws, dw):
sequence1 = encode_sequence(test_writes1, ws) sequence1 = encode_sequence(test_writes1, ws, dw)
sequence2 = encode_sequence(test_writes2, ws) sequence2 = encode_sequence(test_writes2, ws, dw)
offset = 512//ws offset = 512//ws
assert len(sequence1) < offset assert len(sequence1) < offset
sequence = ( sequence = (
@ -97,7 +98,7 @@ class TB(Module):
bus = wishbone.Interface(ws*8) bus = wishbone.Interface(ws*8)
self.submodules.memory = wishbone.SRAM( self.submodules.memory = wishbone.SRAM(
1024, init=sequence, bus=bus) 1024, init=sequence, bus=bus)
self.submodules.dut = dma.DMA(bus) self.submodules.dut = dma.DMA(bus, dw)
test_writes_full_stack = [ test_writes_full_stack = [
@ -109,7 +110,7 @@ test_writes_full_stack = [
class FullStackTB(Module): class FullStackTB(Module):
def __init__(self, ws): def __init__(self, ws, dw):
self.ttl0 = Signal() self.ttl0 = Signal()
self.ttl1 = Signal() self.ttl1 = Signal()
@ -121,12 +122,12 @@ class FullStackTB(Module):
rtio.Channel.from_phy(self.phy1) rtio.Channel.from_phy(self.phy1)
] ]
sequence = encode_sequence(test_writes_full_stack, ws) sequence = encode_sequence(test_writes_full_stack, ws, dw)
bus = wishbone.Interface(ws*8) bus = wishbone.Interface(ws*8, 32-log2_int(dw//8))
self.submodules.memory = wishbone.SRAM( self.submodules.memory = wishbone.SRAM(
256, init=sequence, bus=bus) 256, init=sequence, bus=bus)
self.submodules.dut = dma.DMA(bus) self.submodules.dut = dma.DMA(bus, dw)
self.submodules.tsc = rtio.TSC("async") self.submodules.tsc = rtio.TSC("async")
self.submodules.rtio = rtio.Core(self.tsc, rtio_channels) self.submodules.rtio = rtio.Core(self.tsc, rtio_channels)
self.comb += self.dut.cri.connect(self.rtio.cri) self.comb += self.dut.cri.connect(self.rtio.cri)
@ -134,16 +135,22 @@ class FullStackTB(Module):
class TestDMA(unittest.TestCase): class TestDMA(unittest.TestCase):
def test_dma_noerror(self): def test_dma_noerror(self):
tb = TB(64) tb = {
32: TB(64, 32),
64: TB(64, 64)
}
def do_writes(): def do_writes(dw):
yield from do_dma(tb.dut, 0) yield from do_dma(tb[dw].dut, 0)
yield from do_dma(tb.dut, 512) yield from do_dma(tb[dw].dut, 512)
received = [] received = {
32: [],
64: []
}
@passive @passive
def rtio_sim(): def rtio_sim(dw):
dut_cri = tb.dut.cri dut_cri = tb[dw].dut.cri
while True: while True:
cmd = yield dut_cri.cmd cmd = yield dut_cri.cmd
if cmd == cri.commands["nop"]: if cmd == cri.commands["nop"]:
@ -153,7 +160,7 @@ class TestDMA(unittest.TestCase):
timestamp = yield dut_cri.o_timestamp timestamp = yield dut_cri.o_timestamp
address = yield dut_cri.o_address address = yield dut_cri.o_address
data = yield dut_cri.o_data data = yield dut_cri.o_data
received.append((channel, timestamp, address, data)) received[dw].append((channel, timestamp, address, data))
yield dut_cri.o_status.eq(1) yield dut_cri.o_status.eq(1)
for i in range(prng.randrange(10)): for i in range(prng.randrange(10)):
@ -163,32 +170,46 @@ class TestDMA(unittest.TestCase):
self.fail("unexpected RTIO command") self.fail("unexpected RTIO command")
yield yield
run_simulation(tb, [do_writes(), rtio_sim()]) run_simulation(tb[32], [do_writes(32), rtio_sim(32)])
self.assertEqual(received, test_writes1 + test_writes2) self.assertEqual(received[32], test_writes1 + test_writes2)
run_simulation(tb[64], [do_writes(64), rtio_sim(64)])
self.assertEqual(received[64], test_writes1 + test_writes2)
def test_full_stack(self): def test_full_stack(self):
tb = FullStackTB(64) tb = {
32: FullStackTB(64, 32),
64: FullStackTB(64, 64)
}
ttl_changes = [] ttl_changes = {
32: [],
64: []
}
@passive @passive
def monitor(): def monitor(dw):
old_ttl_states = [0, 0] old_ttl_states = [0, 0]
for time in itertools.count(): for time in itertools.count():
ttl_states = [ ttl_states = [
(yield tb.ttl0), (yield tb[dw].ttl0),
(yield tb.ttl1) (yield tb[dw].ttl1)
] ]
for i, (old, new) in enumerate(zip(old_ttl_states, ttl_states)): for i, (old, new) in enumerate(zip(old_ttl_states, ttl_states)):
if new != old: if new != old:
ttl_changes.append((time, i)) ttl_changes[dw].append((time, i))
old_ttl_states = ttl_states old_ttl_states = ttl_states
yield yield
run_simulation(tb, {"sys": [ run_simulation(tb[32], {"sys": [
do_dma(tb.dut, 0), monitor(), do_dma(tb[32].dut, 0), monitor(32),
(None for _ in range(70)),
]}, {"sys": 8, "rsys": 8, "rtio": 8, "rio": 8, "rio_phy": 8})
run_simulation(tb[64], {"sys": [
do_dma(tb[64].dut, 0), monitor(64),
(None for _ in range(70)), (None for _ in range(70)),
]}, {"sys": 8, "rsys": 8, "rtio": 8, "rio": 8, "rio_phy": 8}) ]}, {"sys": 8, "rsys": 8, "rtio": 8, "rio": 8, "rio_phy": 8})
correct_changes = [(timestamp + 11, channel) correct_changes = [(timestamp + 11, channel)
for channel, timestamp, _, _ in test_writes_full_stack] for channel, timestamp, _, _ in test_writes_full_stack]
self.assertEqual(ttl_changes, correct_changes) self.assertEqual(ttl_changes[32], correct_changes)
self.assertEqual(ttl_changes[64], correct_changes)