2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-28 20:53:35 +08:00

DMA: add API for a much faster replay using handles.

This commit is contained in:
whitequark 2017-04-18 08:11:14 +00:00
parent c6e8d5c901
commit 41c4de4556
4 changed files with 125 additions and 57 deletions

View File

@ -6,7 +6,8 @@ alone could achieve.
"""
from artiq.language.core import syscall, kernel
from artiq.language.types import TInt64, TStr, TNone
from artiq.language.types import TInt32, TInt64, TStr, TNone, TTuple
from artiq.coredevice.exceptions import DMAError
from numpy import int64
@ -24,7 +25,11 @@ def dma_erase(name: TStr) -> TNone:
raise NotImplementedError("syscall not simulated")
@syscall
def dma_playback(timestamp: TInt64, name: TStr) -> TNone:
def dma_retrieve(name: TStr) -> TTuple([TInt64, TInt32]):
raise NotImplementedError("syscall not simulated")
@syscall
def dma_playback(timestamp: TInt64, ptr: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
@ -66,22 +71,45 @@ class CoreDMA:
def __init__(self, dmgr, core_device="core"):
self.core = dmgr.get(core_device)
self.recorder = DMARecordContextManager()
self.epoch = 0
@kernel
def record(self, name):
"""Returns a context manager that will record a DMA trace called ``name``.
Any previously recorded trace with the same name is overwritten.
The trace will persist across kernel switches."""
self.epoch += 1
self.recorder.name = name
return self.recorder
@kernel
def erase(self, name):
"""Removes the DMA trace with the given name from storage."""
self.epoch += 1
dma_erase(name)
@kernel
def replay(self, name):
def playback(self, name):
"""Replays a previously recorded DMA trace. This function blocks until
the entire trace is submitted to the RTIO FIFOs."""
dma_playback(now_mu(), name)
(advance_mu, ptr) = dma_retrieve(name)
dma_playback(now_mu(), ptr)
delay_mu(advance_mu)
@kernel
def get_handle(self, name):
"""Returns a handle to a previously recorded DMA trace. The returned handle
is only valid until the next call to :meth:`record` or :meth:`erase`."""
(advance_mu, ptr) = dma_retrieve(name)
return (self.epoch, advance_mu, ptr)
@kernel
def playback_handle(self, handle):
"""Replays a handle obtained with :meth:`get_handle`. Using this function
is much faster than :meth:`playback` for replaying a set of traces repeatedly,
but incurs the overhead of managing the handles onto the programmer."""
(epoch, advance_mu, ptr) = handle
if self.epoch != epoch:
raise DMAError("Invalid handle")
dma_playback(now_mu(), ptr)
delay_mu(advance_mu)

View File

@ -108,6 +108,7 @@ static mut API: &'static [(&'static str, *const ())] = &[
api!(dma_record_start = ::dma_record_start),
api!(dma_record_stop = ::dma_record_stop),
api!(dma_erase = ::dma_erase),
api!(dma_retrieve = ::dma_retrieve),
api!(dma_playback = ::dma_playback),
api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state),

View File

@ -360,39 +360,43 @@ extern fn dma_erase(name: CSlice<u8>) {
send(&DmaEraseRequest { name: name });
}
extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
#[repr(C)]
struct DmaTrace {
duration: i64,
address: i32,
}
extern fn dma_retrieve(name: CSlice<u8>) -> DmaTrace {
let name = str::from_utf8(name.as_ref()).unwrap();
send(&DmaPlaybackRequest { name: name });
let (succeeded, now_advance) =
recv!(&DmaPlaybackReply { trace, duration } => unsafe {
recv!(&DmaPlaybackReply { trace, duration } => {
match trace {
Some(bytes) => {
let ptr = bytes.as_ptr() as usize;
assert!(ptr % 64 == 0);
csr::rtio_dma::base_address_write(ptr as u64);
csr::rtio_dma::time_offset_write(timestamp as u64);
csr::cri_con::selected_write(1);
csr::rtio_dma::enable_write(1);
while csr::rtio_dma::enable_read() != 0 {}
csr::cri_con::selected_write(0);
(true, duration)
}
None =>
(false, 0)
Some(bytes) => Ok(DmaTrace {
address: bytes.as_ptr() as i32,
duration: duration as i64
}),
None => Err(())
}
});
if !succeeded {
}).unwrap_or_else(|()| {
println!("DMA trace called {:?} not found", name);
raise!("DMAError",
"DMA trace not found");
}
})
}
extern fn dma_playback(timestamp: i64, ptr: i32) {
assert!(ptr % 64 == 0);
unsafe {
csr::rtio_dma::base_address_write(ptr as u64);
csr::rtio_dma::time_offset_write(timestamp as u64);
csr::cri_con::selected_write(1);
csr::rtio_dma::enable_write(1);
while csr::rtio_dma::enable_read() != 0 {}
csr::cri_con::selected_write(0);
let status = csr::rtio_dma::error_status_read();
let timestamp = csr::rtio_dma::error_timestamp_read();
let channel = csr::rtio_dma::error_channel_read();
@ -408,8 +412,6 @@ extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
"RTIO sequence error at {0} mu, channel {1}",
timestamp as i64, channel as i64, 0)
}
NOW += now_advance;
}
}

View File

@ -501,19 +501,28 @@ class _DMA(EnvExperiment):
self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1))
@kernel
def replay(self):
self.core.break_realtime()
delay(100*ms)
self.core_dma.replay(self.trace_name)
@kernel
def replay_delta(self):
def playback(self, use_handle=False):
self.core.break_realtime()
delay(100*ms)
start = now_mu()
self.core_dma.replay(self.trace_name)
if use_handle:
handle = self.core_dma.get_handle(self.trace_name)
self.core_dma.playback_handle(handle)
else:
self.core_dma.playback(self.trace_name)
self.delta = now_mu() - start
@kernel
def playback_many(self, n):
self.core.break_realtime()
delay(100*ms)
handle = self.core_dma.get_handle(self.trace_name)
t1 = self.core.get_rtio_counter_mu()
for i in range(n):
self.core_dma.playback_handle(handle)
t2 = self.core.get_rtio_counter_mu()
self.set_dataset("dma_playback_time", self.core.mu_to_seconds(t2 - t1))
@kernel
def erase(self):
self.core_dma.erase(self.trace_name)
@ -524,16 +533,26 @@ class _DMA(EnvExperiment):
with self.core_dma.record(self.trace_name):
pass
@kernel
def invalidate(self, mode):
self.record()
handle = self.core_dma.get_handle(self.trace_name)
if mode == 0:
self.record()
elif mode == 1:
self.erase()
self.core_dma.playback_handle(handle)
class DMATest(ExperimentCase):
def test_dma_storage(self):
exp = self.create(_DMA)
exp.record()
exp.record() # overwrite
exp.replay()
exp.playback()
exp.erase()
with self.assertRaises(exceptions.DMAError):
exp.replay()
exp.playback()
def test_dma_nested(self):
exp = self.create(_DMA)
@ -545,28 +564,32 @@ class DMATest(ExperimentCase):
exp = self.create(_DMA)
exp.record()
get_analyzer_dump(core_host) # clear analyzer buffer
exp.replay()
dump = decode_dump(get_analyzer_dump(core_host))
self.assertEqual(len(dump.messages), 3)
self.assertIsInstance(dump.messages[-1], StoppedMessage)
self.assertIsInstance(dump.messages[0], OutputMessage)
self.assertEqual(dump.messages[0].channel, 1)
self.assertEqual(dump.messages[0].address, 0)
self.assertEqual(dump.messages[0].data, 1)
self.assertIsInstance(dump.messages[1], OutputMessage)
self.assertEqual(dump.messages[1].channel, 1)
self.assertEqual(dump.messages[1].address, 0)
self.assertEqual(dump.messages[1].data, 0)
self.assertEqual(dump.messages[1].timestamp -
dump.messages[0].timestamp, 100)
for use_handle in [False, True]:
get_analyzer_dump(core_host) # clear analyzer buffer
exp.playback(use_handle)
dump = decode_dump(get_analyzer_dump(core_host))
self.assertEqual(len(dump.messages), 3)
self.assertIsInstance(dump.messages[-1], StoppedMessage)
self.assertIsInstance(dump.messages[0], OutputMessage)
self.assertEqual(dump.messages[0].channel, 1)
self.assertEqual(dump.messages[0].address, 0)
self.assertEqual(dump.messages[0].data, 1)
self.assertIsInstance(dump.messages[1], OutputMessage)
self.assertEqual(dump.messages[1].channel, 1)
self.assertEqual(dump.messages[1].address, 0)
self.assertEqual(dump.messages[1].data, 0)
self.assertEqual(dump.messages[1].timestamp -
dump.messages[0].timestamp, 100)
def test_dma_delta(self):
exp = self.create(_DMA)
exp.record()
exp.replay_delta()
self.assertEqual(exp.delta, 200)
for use_handle in [False, True]:
exp.playback(use_handle)
self.assertEqual(exp.delta, 200)
def test_dma_record_time(self):
exp = self.create(_DMA)
@ -574,4 +597,18 @@ class DMATest(ExperimentCase):
exp.record_many(count)
dt = self.dataset_mgr.get("dma_record_time")
print("dt={}, dt/count={}".format(dt, dt/count))
self.assertLess(dt/count, 15*us)
self.assertLess(dt/count, 16*us)
def test_dma_playback_time(self):
exp = self.create(_DMA)
count = 20000
exp.playback_many(count)
dt = self.dataset_mgr.get("dma_playback_time")
print("dt={}, dt/count={}".format(dt, dt/count))
self.assertLess(dt/count, 7*us)
def test_handle_invalidation(self):
exp = self.create(_DMA)
for mode in [0, 1]:
with self.assertRaises(exceptions.DMAError):
exp.invalidate(mode)