DMA: add API for a much faster replay using handles.

This commit is contained in:
whitequark 2017-04-18 08:11:14 +00:00
parent c6e8d5c901
commit 41c4de4556
4 changed files with 125 additions and 57 deletions

View File

@ -6,7 +6,8 @@ alone could achieve.
""" """
from artiq.language.core import syscall, kernel from artiq.language.core import syscall, kernel
from artiq.language.types import TInt64, TStr, TNone from artiq.language.types import TInt32, TInt64, TStr, TNone, TTuple
from artiq.coredevice.exceptions import DMAError
from numpy import int64 from numpy import int64
@ -24,7 +25,11 @@ def dma_erase(name: TStr) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall @syscall
def dma_playback(timestamp: TInt64, name: TStr) -> TNone: def dma_retrieve(name: TStr) -> TTuple([TInt64, TInt32]):
raise NotImplementedError("syscall not simulated")
@syscall
def dma_playback(timestamp: TInt64, ptr: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@ -66,22 +71,45 @@ class CoreDMA:
def __init__(self, dmgr, core_device="core"): def __init__(self, dmgr, core_device="core"):
self.core = dmgr.get(core_device) self.core = dmgr.get(core_device)
self.recorder = DMARecordContextManager() self.recorder = DMARecordContextManager()
self.epoch = 0
@kernel @kernel
def record(self, name): def record(self, name):
"""Returns a context manager that will record a DMA trace called ``name``. """Returns a context manager that will record a DMA trace called ``name``.
Any previously recorded trace with the same name is overwritten. Any previously recorded trace with the same name is overwritten.
The trace will persist across kernel switches.""" The trace will persist across kernel switches."""
self.epoch += 1
self.recorder.name = name self.recorder.name = name
return self.recorder return self.recorder
@kernel @kernel
def erase(self, name): def erase(self, name):
"""Removes the DMA trace with the given name from storage.""" """Removes the DMA trace with the given name from storage."""
self.epoch += 1
dma_erase(name) dma_erase(name)
@kernel @kernel
def replay(self, name): def playback(self, name):
"""Replays a previously recorded DMA trace. This function blocks until """Replays a previously recorded DMA trace. This function blocks until
the entire trace is submitted to the RTIO FIFOs.""" the entire trace is submitted to the RTIO FIFOs."""
dma_playback(now_mu(), name) (advance_mu, ptr) = dma_retrieve(name)
dma_playback(now_mu(), ptr)
delay_mu(advance_mu)
@kernel
def get_handle(self, name):
"""Returns a handle to a previously recorded DMA trace. The returned handle
is only valid until the next call to :meth:`record` or :meth:`erase`."""
(advance_mu, ptr) = dma_retrieve(name)
return (self.epoch, advance_mu, ptr)
@kernel
def playback_handle(self, handle):
"""Replays a handle obtained with :meth:`get_handle`. Using this function
is much faster than :meth:`playback` for replaying a set of traces repeatedly,
but incurs the overhead of managing the handles onto the programmer."""
(epoch, advance_mu, ptr) = handle
if self.epoch != epoch:
raise DMAError("Invalid handle")
dma_playback(now_mu(), ptr)
delay_mu(advance_mu)

View File

@ -108,6 +108,7 @@ static mut API: &'static [(&'static str, *const ())] = &[
api!(dma_record_start = ::dma_record_start), api!(dma_record_start = ::dma_record_start),
api!(dma_record_stop = ::dma_record_stop), api!(dma_record_stop = ::dma_record_stop),
api!(dma_erase = ::dma_erase), api!(dma_erase = ::dma_erase),
api!(dma_retrieve = ::dma_retrieve),
api!(dma_playback = ::dma_playback), api!(dma_playback = ::dma_playback),
api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state), api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state),

View File

@ -360,39 +360,43 @@ extern fn dma_erase(name: CSlice<u8>) {
send(&DmaEraseRequest { name: name }); send(&DmaEraseRequest { name: name });
} }
extern fn dma_playback(timestamp: i64, name: CSlice<u8>) { #[repr(C)]
struct DmaTrace {
duration: i64,
address: i32,
}
extern fn dma_retrieve(name: CSlice<u8>) -> DmaTrace {
let name = str::from_utf8(name.as_ref()).unwrap(); let name = str::from_utf8(name.as_ref()).unwrap();
send(&DmaPlaybackRequest { name: name }); send(&DmaPlaybackRequest { name: name });
let (succeeded, now_advance) = recv!(&DmaPlaybackReply { trace, duration } => {
recv!(&DmaPlaybackReply { trace, duration } => unsafe {
match trace { match trace {
Some(bytes) => { Some(bytes) => Ok(DmaTrace {
let ptr = bytes.as_ptr() as usize; address: bytes.as_ptr() as i32,
assert!(ptr % 64 == 0); duration: duration as i64
}),
csr::rtio_dma::base_address_write(ptr as u64); None => Err(())
csr::rtio_dma::time_offset_write(timestamp as u64);
csr::cri_con::selected_write(1);
csr::rtio_dma::enable_write(1);
while csr::rtio_dma::enable_read() != 0 {}
csr::cri_con::selected_write(0);
(true, duration)
}
None =>
(false, 0)
} }
}); }).unwrap_or_else(|()| {
if !succeeded {
println!("DMA trace called {:?} not found", name); println!("DMA trace called {:?} not found", name);
raise!("DMAError", raise!("DMAError",
"DMA trace not found"); "DMA trace not found");
} })
}
extern fn dma_playback(timestamp: i64, ptr: i32) {
assert!(ptr % 64 == 0);
unsafe { unsafe {
csr::rtio_dma::base_address_write(ptr as u64);
csr::rtio_dma::time_offset_write(timestamp as u64);
csr::cri_con::selected_write(1);
csr::rtio_dma::enable_write(1);
while csr::rtio_dma::enable_read() != 0 {}
csr::cri_con::selected_write(0);
let status = csr::rtio_dma::error_status_read(); let status = csr::rtio_dma::error_status_read();
let timestamp = csr::rtio_dma::error_timestamp_read(); let timestamp = csr::rtio_dma::error_timestamp_read();
let channel = csr::rtio_dma::error_channel_read(); let channel = csr::rtio_dma::error_channel_read();
@ -408,8 +412,6 @@ extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
"RTIO sequence error at {0} mu, channel {1}", "RTIO sequence error at {0} mu, channel {1}",
timestamp as i64, channel as i64, 0) timestamp as i64, channel as i64, 0)
} }
NOW += now_advance;
} }
} }

View File

@ -501,19 +501,28 @@ class _DMA(EnvExperiment):
self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1)) self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1))
@kernel @kernel
def replay(self): def playback(self, use_handle=False):
self.core.break_realtime()
delay(100*ms)
self.core_dma.replay(self.trace_name)
@kernel
def replay_delta(self):
self.core.break_realtime() self.core.break_realtime()
delay(100*ms) delay(100*ms)
start = now_mu() start = now_mu()
self.core_dma.replay(self.trace_name) if use_handle:
handle = self.core_dma.get_handle(self.trace_name)
self.core_dma.playback_handle(handle)
else:
self.core_dma.playback(self.trace_name)
self.delta = now_mu() - start self.delta = now_mu() - start
@kernel
def playback_many(self, n):
self.core.break_realtime()
delay(100*ms)
handle = self.core_dma.get_handle(self.trace_name)
t1 = self.core.get_rtio_counter_mu()
for i in range(n):
self.core_dma.playback_handle(handle)
t2 = self.core.get_rtio_counter_mu()
self.set_dataset("dma_playback_time", self.core.mu_to_seconds(t2 - t1))
@kernel @kernel
def erase(self): def erase(self):
self.core_dma.erase(self.trace_name) self.core_dma.erase(self.trace_name)
@ -524,16 +533,26 @@ class _DMA(EnvExperiment):
with self.core_dma.record(self.trace_name): with self.core_dma.record(self.trace_name):
pass pass
@kernel
def invalidate(self, mode):
self.record()
handle = self.core_dma.get_handle(self.trace_name)
if mode == 0:
self.record()
elif mode == 1:
self.erase()
self.core_dma.playback_handle(handle)
class DMATest(ExperimentCase): class DMATest(ExperimentCase):
def test_dma_storage(self): def test_dma_storage(self):
exp = self.create(_DMA) exp = self.create(_DMA)
exp.record() exp.record()
exp.record() # overwrite exp.record() # overwrite
exp.replay() exp.playback()
exp.erase() exp.erase()
with self.assertRaises(exceptions.DMAError): with self.assertRaises(exceptions.DMAError):
exp.replay() exp.playback()
def test_dma_nested(self): def test_dma_nested(self):
exp = self.create(_DMA) exp = self.create(_DMA)
@ -545,28 +564,32 @@ class DMATest(ExperimentCase):
exp = self.create(_DMA) exp = self.create(_DMA)
exp.record() exp.record()
get_analyzer_dump(core_host) # clear analyzer buffer
exp.replay()
dump = decode_dump(get_analyzer_dump(core_host)) for use_handle in [False, True]:
self.assertEqual(len(dump.messages), 3) get_analyzer_dump(core_host) # clear analyzer buffer
self.assertIsInstance(dump.messages[-1], StoppedMessage) exp.playback(use_handle)
self.assertIsInstance(dump.messages[0], OutputMessage)
self.assertEqual(dump.messages[0].channel, 1) dump = decode_dump(get_analyzer_dump(core_host))
self.assertEqual(dump.messages[0].address, 0) self.assertEqual(len(dump.messages), 3)
self.assertEqual(dump.messages[0].data, 1) self.assertIsInstance(dump.messages[-1], StoppedMessage)
self.assertIsInstance(dump.messages[1], OutputMessage) self.assertIsInstance(dump.messages[0], OutputMessage)
self.assertEqual(dump.messages[1].channel, 1) self.assertEqual(dump.messages[0].channel, 1)
self.assertEqual(dump.messages[1].address, 0) self.assertEqual(dump.messages[0].address, 0)
self.assertEqual(dump.messages[1].data, 0) self.assertEqual(dump.messages[0].data, 1)
self.assertEqual(dump.messages[1].timestamp - self.assertIsInstance(dump.messages[1], OutputMessage)
dump.messages[0].timestamp, 100) self.assertEqual(dump.messages[1].channel, 1)
self.assertEqual(dump.messages[1].address, 0)
self.assertEqual(dump.messages[1].data, 0)
self.assertEqual(dump.messages[1].timestamp -
dump.messages[0].timestamp, 100)
def test_dma_delta(self): def test_dma_delta(self):
exp = self.create(_DMA) exp = self.create(_DMA)
exp.record() exp.record()
exp.replay_delta()
self.assertEqual(exp.delta, 200) for use_handle in [False, True]:
exp.playback(use_handle)
self.assertEqual(exp.delta, 200)
def test_dma_record_time(self): def test_dma_record_time(self):
exp = self.create(_DMA) exp = self.create(_DMA)
@ -574,4 +597,18 @@ class DMATest(ExperimentCase):
exp.record_many(count) exp.record_many(count)
dt = self.dataset_mgr.get("dma_record_time") dt = self.dataset_mgr.get("dma_record_time")
print("dt={}, dt/count={}".format(dt, dt/count)) print("dt={}, dt/count={}".format(dt, dt/count))
self.assertLess(dt/count, 15*us) self.assertLess(dt/count, 16*us)
def test_dma_playback_time(self):
exp = self.create(_DMA)
count = 20000
exp.playback_many(count)
dt = self.dataset_mgr.get("dma_playback_time")
print("dt={}, dt/count={}".format(dt, dt/count))
self.assertLess(dt/count, 7*us)
def test_handle_invalidation(self):
exp = self.create(_DMA)
for mode in [0, 1]:
with self.assertRaises(exceptions.DMAError):
exp.invalidate(mode)