diff --git a/artiq/coredevice/dma.py b/artiq/coredevice/dma.py index 9148863d3..eafe98e6c 100644 --- a/artiq/coredevice/dma.py +++ b/artiq/coredevice/dma.py @@ -6,7 +6,8 @@ alone could achieve. """ from artiq.language.core import syscall, kernel -from artiq.language.types import TInt64, TStr, TNone +from artiq.language.types import TInt32, TInt64, TStr, TNone, TTuple +from artiq.coredevice.exceptions import DMAError from numpy import int64 @@ -24,7 +25,11 @@ def dma_erase(name: TStr) -> TNone: raise NotImplementedError("syscall not simulated") @syscall -def dma_playback(timestamp: TInt64, name: TStr) -> TNone: +def dma_retrieve(name: TStr) -> TTuple([TInt64, TInt32]): + raise NotImplementedError("syscall not simulated") + +@syscall +def dma_playback(timestamp: TInt64, ptr: TInt32) -> TNone: raise NotImplementedError("syscall not simulated") @@ -66,22 +71,45 @@ class CoreDMA: def __init__(self, dmgr, core_device="core"): self.core = dmgr.get(core_device) self.recorder = DMARecordContextManager() + self.epoch = 0 @kernel def record(self, name): """Returns a context manager that will record a DMA trace called ``name``. Any previously recorded trace with the same name is overwritten. The trace will persist across kernel switches.""" + self.epoch += 1 self.recorder.name = name return self.recorder @kernel def erase(self, name): """Removes the DMA trace with the given name from storage.""" + self.epoch += 1 dma_erase(name) @kernel - def replay(self, name): + def playback(self, name): """Replays a previously recorded DMA trace. This function blocks until the entire trace is submitted to the RTIO FIFOs.""" - dma_playback(now_mu(), name) + (advance_mu, ptr) = dma_retrieve(name) + dma_playback(now_mu(), ptr) + delay_mu(advance_mu) + + @kernel + def get_handle(self, name): + """Returns a handle to a previously recorded DMA trace. The returned handle + is only valid until the next call to :meth:`record` or :meth:`erase`.""" + (advance_mu, ptr) = dma_retrieve(name) + return (self.epoch, advance_mu, ptr) + + @kernel + def playback_handle(self, handle): + """Replays a handle obtained with :meth:`get_handle`. Using this function + is much faster than :meth:`playback` for replaying a set of traces repeatedly, + but incurs the overhead of managing the handles onto the programmer.""" + (epoch, advance_mu, ptr) = handle + if self.epoch != epoch: + raise DMAError("Invalid handle") + dma_playback(now_mu(), ptr) + delay_mu(advance_mu) diff --git a/artiq/firmware/ksupport/api.rs b/artiq/firmware/ksupport/api.rs index 771c580b5..efb458f68 100644 --- a/artiq/firmware/ksupport/api.rs +++ b/artiq/firmware/ksupport/api.rs @@ -108,6 +108,7 @@ static mut API: &'static [(&'static str, *const ())] = &[ api!(dma_record_start = ::dma_record_start), api!(dma_record_stop = ::dma_record_stop), api!(dma_erase = ::dma_erase), + api!(dma_retrieve = ::dma_retrieve), api!(dma_playback = ::dma_playback), api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state), diff --git a/artiq/firmware/ksupport/lib.rs b/artiq/firmware/ksupport/lib.rs index 620115121..ea782f21b 100644 --- a/artiq/firmware/ksupport/lib.rs +++ b/artiq/firmware/ksupport/lib.rs @@ -360,39 +360,43 @@ extern fn dma_erase(name: CSlice) { send(&DmaEraseRequest { name: name }); } -extern fn dma_playback(timestamp: i64, name: CSlice) { +#[repr(C)] +struct DmaTrace { + duration: i64, + address: i32, +} + +extern fn dma_retrieve(name: CSlice) -> DmaTrace { let name = str::from_utf8(name.as_ref()).unwrap(); send(&DmaPlaybackRequest { name: name }); - let (succeeded, now_advance) = - recv!(&DmaPlaybackReply { trace, duration } => unsafe { + recv!(&DmaPlaybackReply { trace, duration } => { match trace { - Some(bytes) => { - let ptr = bytes.as_ptr() as usize; - assert!(ptr % 64 == 0); - - csr::rtio_dma::base_address_write(ptr as u64); - csr::rtio_dma::time_offset_write(timestamp as u64); - - csr::cri_con::selected_write(1); - csr::rtio_dma::enable_write(1); - while csr::rtio_dma::enable_read() != 0 {} - csr::cri_con::selected_write(0); - - (true, duration) - } - None => - (false, 0) + Some(bytes) => Ok(DmaTrace { + address: bytes.as_ptr() as i32, + duration: duration as i64 + }), + None => Err(()) } - }); - - if !succeeded { + }).unwrap_or_else(|()| { println!("DMA trace called {:?} not found", name); raise!("DMAError", "DMA trace not found"); - } + }) +} + +extern fn dma_playback(timestamp: i64, ptr: i32) { + assert!(ptr % 64 == 0); unsafe { + csr::rtio_dma::base_address_write(ptr as u64); + csr::rtio_dma::time_offset_write(timestamp as u64); + + csr::cri_con::selected_write(1); + csr::rtio_dma::enable_write(1); + while csr::rtio_dma::enable_read() != 0 {} + csr::cri_con::selected_write(0); + let status = csr::rtio_dma::error_status_read(); let timestamp = csr::rtio_dma::error_timestamp_read(); let channel = csr::rtio_dma::error_channel_read(); @@ -408,8 +412,6 @@ extern fn dma_playback(timestamp: i64, name: CSlice) { "RTIO sequence error at {0} mu, channel {1}", timestamp as i64, channel as i64, 0) } - - NOW += now_advance; } } diff --git a/artiq/test/coredevice/test_rtio.py b/artiq/test/coredevice/test_rtio.py index d0a1bb09b..0aff9b979 100644 --- a/artiq/test/coredevice/test_rtio.py +++ b/artiq/test/coredevice/test_rtio.py @@ -501,19 +501,28 @@ class _DMA(EnvExperiment): self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1)) @kernel - def replay(self): - self.core.break_realtime() - delay(100*ms) - self.core_dma.replay(self.trace_name) - - @kernel - def replay_delta(self): + def playback(self, use_handle=False): self.core.break_realtime() delay(100*ms) start = now_mu() - self.core_dma.replay(self.trace_name) + if use_handle: + handle = self.core_dma.get_handle(self.trace_name) + self.core_dma.playback_handle(handle) + else: + self.core_dma.playback(self.trace_name) self.delta = now_mu() - start + @kernel + def playback_many(self, n): + self.core.break_realtime() + delay(100*ms) + handle = self.core_dma.get_handle(self.trace_name) + t1 = self.core.get_rtio_counter_mu() + for i in range(n): + self.core_dma.playback_handle(handle) + t2 = self.core.get_rtio_counter_mu() + self.set_dataset("dma_playback_time", self.core.mu_to_seconds(t2 - t1)) + @kernel def erase(self): self.core_dma.erase(self.trace_name) @@ -524,16 +533,26 @@ class _DMA(EnvExperiment): with self.core_dma.record(self.trace_name): pass + @kernel + def invalidate(self, mode): + self.record() + handle = self.core_dma.get_handle(self.trace_name) + if mode == 0: + self.record() + elif mode == 1: + self.erase() + self.core_dma.playback_handle(handle) + class DMATest(ExperimentCase): def test_dma_storage(self): exp = self.create(_DMA) exp.record() exp.record() # overwrite - exp.replay() + exp.playback() exp.erase() with self.assertRaises(exceptions.DMAError): - exp.replay() + exp.playback() def test_dma_nested(self): exp = self.create(_DMA) @@ -545,28 +564,32 @@ class DMATest(ExperimentCase): exp = self.create(_DMA) exp.record() - get_analyzer_dump(core_host) # clear analyzer buffer - exp.replay() - dump = decode_dump(get_analyzer_dump(core_host)) - self.assertEqual(len(dump.messages), 3) - self.assertIsInstance(dump.messages[-1], StoppedMessage) - self.assertIsInstance(dump.messages[0], OutputMessage) - self.assertEqual(dump.messages[0].channel, 1) - self.assertEqual(dump.messages[0].address, 0) - self.assertEqual(dump.messages[0].data, 1) - self.assertIsInstance(dump.messages[1], OutputMessage) - self.assertEqual(dump.messages[1].channel, 1) - self.assertEqual(dump.messages[1].address, 0) - self.assertEqual(dump.messages[1].data, 0) - self.assertEqual(dump.messages[1].timestamp - - dump.messages[0].timestamp, 100) + for use_handle in [False, True]: + get_analyzer_dump(core_host) # clear analyzer buffer + exp.playback(use_handle) + + dump = decode_dump(get_analyzer_dump(core_host)) + self.assertEqual(len(dump.messages), 3) + self.assertIsInstance(dump.messages[-1], StoppedMessage) + self.assertIsInstance(dump.messages[0], OutputMessage) + self.assertEqual(dump.messages[0].channel, 1) + self.assertEqual(dump.messages[0].address, 0) + self.assertEqual(dump.messages[0].data, 1) + self.assertIsInstance(dump.messages[1], OutputMessage) + self.assertEqual(dump.messages[1].channel, 1) + self.assertEqual(dump.messages[1].address, 0) + self.assertEqual(dump.messages[1].data, 0) + self.assertEqual(dump.messages[1].timestamp - + dump.messages[0].timestamp, 100) def test_dma_delta(self): exp = self.create(_DMA) exp.record() - exp.replay_delta() - self.assertEqual(exp.delta, 200) + + for use_handle in [False, True]: + exp.playback(use_handle) + self.assertEqual(exp.delta, 200) def test_dma_record_time(self): exp = self.create(_DMA) @@ -574,4 +597,18 @@ class DMATest(ExperimentCase): exp.record_many(count) dt = self.dataset_mgr.get("dma_record_time") print("dt={}, dt/count={}".format(dt, dt/count)) - self.assertLess(dt/count, 15*us) + self.assertLess(dt/count, 16*us) + + def test_dma_playback_time(self): + exp = self.create(_DMA) + count = 20000 + exp.playback_many(count) + dt = self.dataset_mgr.get("dma_playback_time") + print("dt={}, dt/count={}".format(dt, dt/count)) + self.assertLess(dt/count, 7*us) + + def test_handle_invalidation(self): + exp = self.create(_DMA) + for mode in [0, 1]: + with self.assertRaises(exceptions.DMAError): + exp.invalidate(mode)