From 09945ecc4d8b06c9761e46985e7d0ee0a0492cde Mon Sep 17 00:00:00 2001 From: occheung Date: Mon, 8 Nov 2021 13:05:55 +0800 Subject: [PATCH] gateware: fix drtio/dma tests --- .../test/drtio/test_aux_controller.py | 58 +++++++------ artiq/gateware/test/rtio/test_dma.py | 87 ++++++++++++------- 2 files changed, 85 insertions(+), 60 deletions(-) diff --git a/artiq/gateware/test/drtio/test_aux_controller.py b/artiq/gateware/test/drtio/test_aux_controller.py index 64e2e15d7..68c9f3bbd 100644 --- a/artiq/gateware/test/drtio/test_aux_controller.py +++ b/artiq/gateware/test/drtio/test_aux_controller.py @@ -33,63 +33,67 @@ class Loopback(Module): class TB(Module): - def __init__(self, nwords): + def __init__(self, nwords, dw): self.submodules.link_layer = Loopback(nwords) self.submodules.aux_controller = ClockDomainsRenamer( - {"rtio": "sys", "rtio_rx": "sys"})(DRTIOAuxController(self.link_layer)) + {"rtio": "sys", "rtio_rx": "sys"})(DRTIOAuxController(self.link_layer, dw)) class TestAuxController(unittest.TestCase): def test_aux_controller(self): - dut = TB(4) + dut = { + 32: TB(4, 32), + 64: TB(4, 64) + } - def link_init(): + def link_init(dw): for i in range(8): yield - yield dut.link_layer.ready.eq(1) + yield dut[dw].link_layer.ready.eq(1) - def send_packet(packet): + def send_packet(packet, dw): for i, d in enumerate(packet): - yield from dut.aux_controller.bus.write(i, d) - yield from dut.aux_controller.transmitter.aux_tx_length.write(len(packet)*4) - yield from dut.aux_controller.transmitter.aux_tx.write(1) + yield from dut[dw].aux_controller.bus.write(i, d) + yield from dut[dw].aux_controller.transmitter.aux_tx_length.write(len(packet)*dw//8) + yield from dut[dw].aux_controller.transmitter.aux_tx.write(1) yield - while (yield from dut.aux_controller.transmitter.aux_tx.read()): + while (yield from dut[dw].aux_controller.transmitter.aux_tx.read()): yield - def receive_packet(): - while not (yield from dut.aux_controller.receiver.aux_rx_present.read()): + def receive_packet(dw): + while not (yield from dut[dw].aux_controller.receiver.aux_rx_present.read()): yield - length = yield from dut.aux_controller.receiver.aux_rx_length.read() + length = yield from dut[dw].aux_controller.receiver.aux_rx_length.read() r = [] - for i in range(length//4): - r.append((yield from dut.aux_controller.bus.read(256+i))) - yield from dut.aux_controller.receiver.aux_rx_present.write(1) + for i in range(length//(dw//8)): + r.append((yield from dut[dw].aux_controller.bus.read(256+i))) + yield from dut[dw].aux_controller.receiver.aux_rx_present.write(1) return r prng = random.Random(0) - def send_and_check_packet(): - data = [prng.randrange(2**32-1) for _ in range(prng.randrange(1, 16))] - yield from send_packet(data) - received = yield from receive_packet() + def send_and_check_packet(dw): + data = [prng.randrange(2**dw-1) for _ in range(prng.randrange(1, 16))] + yield from send_packet(data, dw) + received = yield from receive_packet(dw) self.assertEqual(data, received) - def sim(): - yield from link_init() + def sim(dw): + yield from link_init(dw) for i in range(8): - yield from send_and_check_packet() + yield from send_and_check_packet(dw) @passive - def rt_traffic(): + def rt_traffic(dw): while True: while prng.randrange(4): yield - yield dut.link_layer.tx_rt_frame.eq(1) + yield dut[dw].link_layer.tx_rt_frame.eq(1) yield while prng.randrange(4): yield - yield dut.link_layer.tx_rt_frame.eq(0) + yield dut[dw].link_layer.tx_rt_frame.eq(0) yield - run_simulation(dut, [sim(), rt_traffic()]) + run_simulation(dut[32], [sim(32), rt_traffic(32)]) + run_simulation(dut[64], [sim(64), rt_traffic(64)]) diff --git a/artiq/gateware/test/rtio/test_dma.py b/artiq/gateware/test/rtio/test_dma.py index 84bc4a3ff..c5d220a9f 100644 --- a/artiq/gateware/test/rtio/test_dma.py +++ b/artiq/gateware/test/rtio/test_dma.py @@ -31,24 +31,25 @@ def encode_record(channel, timestamp, address, data): return encode_n(len(r)+1, 1, 1) + r -def pack(x, size): +def pack(x, size, dw): r = [] for i in range((len(x)+size-1)//size): n = 0 - for j in range(i*size, (i+1)*size): - n <<= 8 + for j in range(i*size//(dw//8), (i+1)*size//(dw//8)): + n <<= dw try: - n |= x[j] + encoded = int.from_bytes(x[j*(dw//8): (j+1)*(dw//8)], "little") + n |= encoded except IndexError: pass r.append(n) return r -def encode_sequence(writes, ws): +def encode_sequence(writes, ws, dw): sequence = [b for write in writes for b in encode_record(*write)] sequence.append(0) - return pack(sequence, ws) + return pack(sequence, ws, dw) def do_dma(dut, address): @@ -84,9 +85,9 @@ prng = random.Random(0) class TB(Module): - def __init__(self, ws): - sequence1 = encode_sequence(test_writes1, ws) - sequence2 = encode_sequence(test_writes2, ws) + def __init__(self, ws, dw): + sequence1 = encode_sequence(test_writes1, ws, dw) + sequence2 = encode_sequence(test_writes2, ws, dw) offset = 512//ws assert len(sequence1) < offset sequence = ( @@ -97,7 +98,7 @@ class TB(Module): bus = wishbone.Interface(ws*8) self.submodules.memory = wishbone.SRAM( 1024, init=sequence, bus=bus) - self.submodules.dut = dma.DMA(bus) + self.submodules.dut = dma.DMA(bus, dw) test_writes_full_stack = [ @@ -109,7 +110,7 @@ test_writes_full_stack = [ class FullStackTB(Module): - def __init__(self, ws): + def __init__(self, ws, dw): self.ttl0 = Signal() self.ttl1 = Signal() @@ -121,12 +122,12 @@ class FullStackTB(Module): rtio.Channel.from_phy(self.phy1) ] - sequence = encode_sequence(test_writes_full_stack, ws) + sequence = encode_sequence(test_writes_full_stack, ws, dw) - bus = wishbone.Interface(ws*8) + bus = wishbone.Interface(ws*8, 32-log2_int(dw//8)) self.submodules.memory = wishbone.SRAM( 256, init=sequence, bus=bus) - self.submodules.dut = dma.DMA(bus) + self.submodules.dut = dma.DMA(bus, dw) self.submodules.tsc = rtio.TSC("async") self.submodules.rtio = rtio.Core(self.tsc, rtio_channels) self.comb += self.dut.cri.connect(self.rtio.cri) @@ -134,16 +135,22 @@ class FullStackTB(Module): class TestDMA(unittest.TestCase): def test_dma_noerror(self): - tb = TB(64) + tb = { + 32: TB(64, 32), + 64: TB(64, 64) + } - def do_writes(): - yield from do_dma(tb.dut, 0) - yield from do_dma(tb.dut, 512) + def do_writes(dw): + yield from do_dma(tb[dw].dut, 0) + yield from do_dma(tb[dw].dut, 512) - received = [] + received = { + 32: [], + 64: [] + } @passive - def rtio_sim(): - dut_cri = tb.dut.cri + def rtio_sim(dw): + dut_cri = tb[dw].dut.cri while True: cmd = yield dut_cri.cmd if cmd == cri.commands["nop"]: @@ -153,7 +160,7 @@ class TestDMA(unittest.TestCase): timestamp = yield dut_cri.o_timestamp address = yield dut_cri.o_address data = yield dut_cri.o_data - received.append((channel, timestamp, address, data)) + received[dw].append((channel, timestamp, address, data)) yield dut_cri.o_status.eq(1) for i in range(prng.randrange(10)): @@ -163,32 +170,46 @@ class TestDMA(unittest.TestCase): self.fail("unexpected RTIO command") yield - run_simulation(tb, [do_writes(), rtio_sim()]) - self.assertEqual(received, test_writes1 + test_writes2) + run_simulation(tb[32], [do_writes(32), rtio_sim(32)]) + self.assertEqual(received[32], test_writes1 + test_writes2) + + run_simulation(tb[64], [do_writes(64), rtio_sim(64)]) + self.assertEqual(received[64], test_writes1 + test_writes2) def test_full_stack(self): - tb = FullStackTB(64) + tb = { + 32: FullStackTB(64, 32), + 64: FullStackTB(64, 64) + } - ttl_changes = [] + ttl_changes = { + 32: [], + 64: [] + } @passive - def monitor(): + def monitor(dw): old_ttl_states = [0, 0] for time in itertools.count(): ttl_states = [ - (yield tb.ttl0), - (yield tb.ttl1) + (yield tb[dw].ttl0), + (yield tb[dw].ttl1) ] for i, (old, new) in enumerate(zip(old_ttl_states, ttl_states)): if new != old: - ttl_changes.append((time, i)) + ttl_changes[dw].append((time, i)) old_ttl_states = ttl_states yield - run_simulation(tb, {"sys": [ - do_dma(tb.dut, 0), monitor(), + run_simulation(tb[32], {"sys": [ + do_dma(tb[32].dut, 0), monitor(32), + (None for _ in range(70)), + ]}, {"sys": 8, "rsys": 8, "rtio": 8, "rio": 8, "rio_phy": 8}) + run_simulation(tb[64], {"sys": [ + do_dma(tb[64].dut, 0), monitor(64), (None for _ in range(70)), ]}, {"sys": 8, "rsys": 8, "rtio": 8, "rio": 8, "rio_phy": 8}) correct_changes = [(timestamp + 11, channel) for channel, timestamp, _, _ in test_writes_full_stack] - self.assertEqual(ttl_changes, correct_changes) + self.assertEqual(ttl_changes[32], correct_changes) + self.assertEqual(ttl_changes[64], correct_changes)