mirror of https://github.com/m-labs/artiq.git
DMA: add API for a much faster replay using handles.
This commit is contained in:
parent
c6e8d5c901
commit
41c4de4556
|
@ -6,7 +6,8 @@ alone could achieve.
|
|||
"""
|
||||
|
||||
from artiq.language.core import syscall, kernel
|
||||
from artiq.language.types import TInt64, TStr, TNone
|
||||
from artiq.language.types import TInt32, TInt64, TStr, TNone, TTuple
|
||||
from artiq.coredevice.exceptions import DMAError
|
||||
|
||||
from numpy import int64
|
||||
|
||||
|
@ -24,7 +25,11 @@ def dma_erase(name: TStr) -> TNone:
|
|||
raise NotImplementedError("syscall not simulated")
|
||||
|
||||
@syscall
|
||||
def dma_playback(timestamp: TInt64, name: TStr) -> TNone:
|
||||
def dma_retrieve(name: TStr) -> TTuple([TInt64, TInt32]):
|
||||
raise NotImplementedError("syscall not simulated")
|
||||
|
||||
@syscall
|
||||
def dma_playback(timestamp: TInt64, ptr: TInt32) -> TNone:
|
||||
raise NotImplementedError("syscall not simulated")
|
||||
|
||||
|
||||
|
@ -66,22 +71,45 @@ class CoreDMA:
|
|||
def __init__(self, dmgr, core_device="core"):
|
||||
self.core = dmgr.get(core_device)
|
||||
self.recorder = DMARecordContextManager()
|
||||
self.epoch = 0
|
||||
|
||||
@kernel
|
||||
def record(self, name):
|
||||
"""Returns a context manager that will record a DMA trace called ``name``.
|
||||
Any previously recorded trace with the same name is overwritten.
|
||||
The trace will persist across kernel switches."""
|
||||
self.epoch += 1
|
||||
self.recorder.name = name
|
||||
return self.recorder
|
||||
|
||||
@kernel
|
||||
def erase(self, name):
|
||||
"""Removes the DMA trace with the given name from storage."""
|
||||
self.epoch += 1
|
||||
dma_erase(name)
|
||||
|
||||
@kernel
|
||||
def replay(self, name):
|
||||
def playback(self, name):
|
||||
"""Replays a previously recorded DMA trace. This function blocks until
|
||||
the entire trace is submitted to the RTIO FIFOs."""
|
||||
dma_playback(now_mu(), name)
|
||||
(advance_mu, ptr) = dma_retrieve(name)
|
||||
dma_playback(now_mu(), ptr)
|
||||
delay_mu(advance_mu)
|
||||
|
||||
@kernel
|
||||
def get_handle(self, name):
|
||||
"""Returns a handle to a previously recorded DMA trace. The returned handle
|
||||
is only valid until the next call to :meth:`record` or :meth:`erase`."""
|
||||
(advance_mu, ptr) = dma_retrieve(name)
|
||||
return (self.epoch, advance_mu, ptr)
|
||||
|
||||
@kernel
|
||||
def playback_handle(self, handle):
|
||||
"""Replays a handle obtained with :meth:`get_handle`. Using this function
|
||||
is much faster than :meth:`playback` for replaying a set of traces repeatedly,
|
||||
but incurs the overhead of managing the handles onto the programmer."""
|
||||
(epoch, advance_mu, ptr) = handle
|
||||
if self.epoch != epoch:
|
||||
raise DMAError("Invalid handle")
|
||||
dma_playback(now_mu(), ptr)
|
||||
delay_mu(advance_mu)
|
||||
|
|
|
@ -108,6 +108,7 @@ static mut API: &'static [(&'static str, *const ())] = &[
|
|||
api!(dma_record_start = ::dma_record_start),
|
||||
api!(dma_record_stop = ::dma_record_stop),
|
||||
api!(dma_erase = ::dma_erase),
|
||||
api!(dma_retrieve = ::dma_retrieve),
|
||||
api!(dma_playback = ::dma_playback),
|
||||
|
||||
api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state),
|
||||
|
|
|
@ -360,39 +360,43 @@ extern fn dma_erase(name: CSlice<u8>) {
|
|||
send(&DmaEraseRequest { name: name });
|
||||
}
|
||||
|
||||
extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
|
||||
#[repr(C)]
|
||||
struct DmaTrace {
|
||||
duration: i64,
|
||||
address: i32,
|
||||
}
|
||||
|
||||
extern fn dma_retrieve(name: CSlice<u8>) -> DmaTrace {
|
||||
let name = str::from_utf8(name.as_ref()).unwrap();
|
||||
|
||||
send(&DmaPlaybackRequest { name: name });
|
||||
let (succeeded, now_advance) =
|
||||
recv!(&DmaPlaybackReply { trace, duration } => unsafe {
|
||||
recv!(&DmaPlaybackReply { trace, duration } => {
|
||||
match trace {
|
||||
Some(bytes) => {
|
||||
let ptr = bytes.as_ptr() as usize;
|
||||
assert!(ptr % 64 == 0);
|
||||
|
||||
csr::rtio_dma::base_address_write(ptr as u64);
|
||||
csr::rtio_dma::time_offset_write(timestamp as u64);
|
||||
|
||||
csr::cri_con::selected_write(1);
|
||||
csr::rtio_dma::enable_write(1);
|
||||
while csr::rtio_dma::enable_read() != 0 {}
|
||||
csr::cri_con::selected_write(0);
|
||||
|
||||
(true, duration)
|
||||
}
|
||||
None =>
|
||||
(false, 0)
|
||||
Some(bytes) => Ok(DmaTrace {
|
||||
address: bytes.as_ptr() as i32,
|
||||
duration: duration as i64
|
||||
}),
|
||||
None => Err(())
|
||||
}
|
||||
});
|
||||
|
||||
if !succeeded {
|
||||
}).unwrap_or_else(|()| {
|
||||
println!("DMA trace called {:?} not found", name);
|
||||
raise!("DMAError",
|
||||
"DMA trace not found");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
extern fn dma_playback(timestamp: i64, ptr: i32) {
|
||||
assert!(ptr % 64 == 0);
|
||||
|
||||
unsafe {
|
||||
csr::rtio_dma::base_address_write(ptr as u64);
|
||||
csr::rtio_dma::time_offset_write(timestamp as u64);
|
||||
|
||||
csr::cri_con::selected_write(1);
|
||||
csr::rtio_dma::enable_write(1);
|
||||
while csr::rtio_dma::enable_read() != 0 {}
|
||||
csr::cri_con::selected_write(0);
|
||||
|
||||
let status = csr::rtio_dma::error_status_read();
|
||||
let timestamp = csr::rtio_dma::error_timestamp_read();
|
||||
let channel = csr::rtio_dma::error_channel_read();
|
||||
|
@ -408,8 +412,6 @@ extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
|
|||
"RTIO sequence error at {0} mu, channel {1}",
|
||||
timestamp as i64, channel as i64, 0)
|
||||
}
|
||||
|
||||
NOW += now_advance;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -501,19 +501,28 @@ class _DMA(EnvExperiment):
|
|||
self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1))
|
||||
|
||||
@kernel
|
||||
def replay(self):
|
||||
self.core.break_realtime()
|
||||
delay(100*ms)
|
||||
self.core_dma.replay(self.trace_name)
|
||||
|
||||
@kernel
|
||||
def replay_delta(self):
|
||||
def playback(self, use_handle=False):
|
||||
self.core.break_realtime()
|
||||
delay(100*ms)
|
||||
start = now_mu()
|
||||
self.core_dma.replay(self.trace_name)
|
||||
if use_handle:
|
||||
handle = self.core_dma.get_handle(self.trace_name)
|
||||
self.core_dma.playback_handle(handle)
|
||||
else:
|
||||
self.core_dma.playback(self.trace_name)
|
||||
self.delta = now_mu() - start
|
||||
|
||||
@kernel
|
||||
def playback_many(self, n):
|
||||
self.core.break_realtime()
|
||||
delay(100*ms)
|
||||
handle = self.core_dma.get_handle(self.trace_name)
|
||||
t1 = self.core.get_rtio_counter_mu()
|
||||
for i in range(n):
|
||||
self.core_dma.playback_handle(handle)
|
||||
t2 = self.core.get_rtio_counter_mu()
|
||||
self.set_dataset("dma_playback_time", self.core.mu_to_seconds(t2 - t1))
|
||||
|
||||
@kernel
|
||||
def erase(self):
|
||||
self.core_dma.erase(self.trace_name)
|
||||
|
@ -524,16 +533,26 @@ class _DMA(EnvExperiment):
|
|||
with self.core_dma.record(self.trace_name):
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def invalidate(self, mode):
|
||||
self.record()
|
||||
handle = self.core_dma.get_handle(self.trace_name)
|
||||
if mode == 0:
|
||||
self.record()
|
||||
elif mode == 1:
|
||||
self.erase()
|
||||
self.core_dma.playback_handle(handle)
|
||||
|
||||
|
||||
class DMATest(ExperimentCase):
|
||||
def test_dma_storage(self):
|
||||
exp = self.create(_DMA)
|
||||
exp.record()
|
||||
exp.record() # overwrite
|
||||
exp.replay()
|
||||
exp.playback()
|
||||
exp.erase()
|
||||
with self.assertRaises(exceptions.DMAError):
|
||||
exp.replay()
|
||||
exp.playback()
|
||||
|
||||
def test_dma_nested(self):
|
||||
exp = self.create(_DMA)
|
||||
|
@ -545,28 +564,32 @@ class DMATest(ExperimentCase):
|
|||
|
||||
exp = self.create(_DMA)
|
||||
exp.record()
|
||||
get_analyzer_dump(core_host) # clear analyzer buffer
|
||||
exp.replay()
|
||||
|
||||
dump = decode_dump(get_analyzer_dump(core_host))
|
||||
self.assertEqual(len(dump.messages), 3)
|
||||
self.assertIsInstance(dump.messages[-1], StoppedMessage)
|
||||
self.assertIsInstance(dump.messages[0], OutputMessage)
|
||||
self.assertEqual(dump.messages[0].channel, 1)
|
||||
self.assertEqual(dump.messages[0].address, 0)
|
||||
self.assertEqual(dump.messages[0].data, 1)
|
||||
self.assertIsInstance(dump.messages[1], OutputMessage)
|
||||
self.assertEqual(dump.messages[1].channel, 1)
|
||||
self.assertEqual(dump.messages[1].address, 0)
|
||||
self.assertEqual(dump.messages[1].data, 0)
|
||||
self.assertEqual(dump.messages[1].timestamp -
|
||||
dump.messages[0].timestamp, 100)
|
||||
for use_handle in [False, True]:
|
||||
get_analyzer_dump(core_host) # clear analyzer buffer
|
||||
exp.playback(use_handle)
|
||||
|
||||
dump = decode_dump(get_analyzer_dump(core_host))
|
||||
self.assertEqual(len(dump.messages), 3)
|
||||
self.assertIsInstance(dump.messages[-1], StoppedMessage)
|
||||
self.assertIsInstance(dump.messages[0], OutputMessage)
|
||||
self.assertEqual(dump.messages[0].channel, 1)
|
||||
self.assertEqual(dump.messages[0].address, 0)
|
||||
self.assertEqual(dump.messages[0].data, 1)
|
||||
self.assertIsInstance(dump.messages[1], OutputMessage)
|
||||
self.assertEqual(dump.messages[1].channel, 1)
|
||||
self.assertEqual(dump.messages[1].address, 0)
|
||||
self.assertEqual(dump.messages[1].data, 0)
|
||||
self.assertEqual(dump.messages[1].timestamp -
|
||||
dump.messages[0].timestamp, 100)
|
||||
|
||||
def test_dma_delta(self):
|
||||
exp = self.create(_DMA)
|
||||
exp.record()
|
||||
exp.replay_delta()
|
||||
self.assertEqual(exp.delta, 200)
|
||||
|
||||
for use_handle in [False, True]:
|
||||
exp.playback(use_handle)
|
||||
self.assertEqual(exp.delta, 200)
|
||||
|
||||
def test_dma_record_time(self):
|
||||
exp = self.create(_DMA)
|
||||
|
@ -574,4 +597,18 @@ class DMATest(ExperimentCase):
|
|||
exp.record_many(count)
|
||||
dt = self.dataset_mgr.get("dma_record_time")
|
||||
print("dt={}, dt/count={}".format(dt, dt/count))
|
||||
self.assertLess(dt/count, 15*us)
|
||||
self.assertLess(dt/count, 16*us)
|
||||
|
||||
def test_dma_playback_time(self):
|
||||
exp = self.create(_DMA)
|
||||
count = 20000
|
||||
exp.playback_many(count)
|
||||
dt = self.dataset_mgr.get("dma_playback_time")
|
||||
print("dt={}, dt/count={}".format(dt, dt/count))
|
||||
self.assertLess(dt/count, 7*us)
|
||||
|
||||
def test_handle_invalidation(self):
|
||||
exp = self.create(_DMA)
|
||||
for mode in [0, 1]:
|
||||
with self.assertRaises(exceptions.DMAError):
|
||||
exp.invalidate(mode)
|
||||
|
|
Loading…
Reference in New Issue