forked from M-Labs/artiq
DMA: add API for a much faster replay using handles.
This commit is contained in:
parent
c6e8d5c901
commit
41c4de4556
|
@ -6,7 +6,8 @@ alone could achieve.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from artiq.language.core import syscall, kernel
|
from artiq.language.core import syscall, kernel
|
||||||
from artiq.language.types import TInt64, TStr, TNone
|
from artiq.language.types import TInt32, TInt64, TStr, TNone, TTuple
|
||||||
|
from artiq.coredevice.exceptions import DMAError
|
||||||
|
|
||||||
from numpy import int64
|
from numpy import int64
|
||||||
|
|
||||||
|
@ -24,7 +25,11 @@ def dma_erase(name: TStr) -> TNone:
|
||||||
raise NotImplementedError("syscall not simulated")
|
raise NotImplementedError("syscall not simulated")
|
||||||
|
|
||||||
@syscall
|
@syscall
|
||||||
def dma_playback(timestamp: TInt64, name: TStr) -> TNone:
|
def dma_retrieve(name: TStr) -> TTuple([TInt64, TInt32]):
|
||||||
|
raise NotImplementedError("syscall not simulated")
|
||||||
|
|
||||||
|
@syscall
|
||||||
|
def dma_playback(timestamp: TInt64, ptr: TInt32) -> TNone:
|
||||||
raise NotImplementedError("syscall not simulated")
|
raise NotImplementedError("syscall not simulated")
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,22 +71,45 @@ class CoreDMA:
|
||||||
def __init__(self, dmgr, core_device="core"):
|
def __init__(self, dmgr, core_device="core"):
|
||||||
self.core = dmgr.get(core_device)
|
self.core = dmgr.get(core_device)
|
||||||
self.recorder = DMARecordContextManager()
|
self.recorder = DMARecordContextManager()
|
||||||
|
self.epoch = 0
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def record(self, name):
|
def record(self, name):
|
||||||
"""Returns a context manager that will record a DMA trace called ``name``.
|
"""Returns a context manager that will record a DMA trace called ``name``.
|
||||||
Any previously recorded trace with the same name is overwritten.
|
Any previously recorded trace with the same name is overwritten.
|
||||||
The trace will persist across kernel switches."""
|
The trace will persist across kernel switches."""
|
||||||
|
self.epoch += 1
|
||||||
self.recorder.name = name
|
self.recorder.name = name
|
||||||
return self.recorder
|
return self.recorder
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def erase(self, name):
|
def erase(self, name):
|
||||||
"""Removes the DMA trace with the given name from storage."""
|
"""Removes the DMA trace with the given name from storage."""
|
||||||
|
self.epoch += 1
|
||||||
dma_erase(name)
|
dma_erase(name)
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def replay(self, name):
|
def playback(self, name):
|
||||||
"""Replays a previously recorded DMA trace. This function blocks until
|
"""Replays a previously recorded DMA trace. This function blocks until
|
||||||
the entire trace is submitted to the RTIO FIFOs."""
|
the entire trace is submitted to the RTIO FIFOs."""
|
||||||
dma_playback(now_mu(), name)
|
(advance_mu, ptr) = dma_retrieve(name)
|
||||||
|
dma_playback(now_mu(), ptr)
|
||||||
|
delay_mu(advance_mu)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def get_handle(self, name):
|
||||||
|
"""Returns a handle to a previously recorded DMA trace. The returned handle
|
||||||
|
is only valid until the next call to :meth:`record` or :meth:`erase`."""
|
||||||
|
(advance_mu, ptr) = dma_retrieve(name)
|
||||||
|
return (self.epoch, advance_mu, ptr)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def playback_handle(self, handle):
|
||||||
|
"""Replays a handle obtained with :meth:`get_handle`. Using this function
|
||||||
|
is much faster than :meth:`playback` for replaying a set of traces repeatedly,
|
||||||
|
but incurs the overhead of managing the handles onto the programmer."""
|
||||||
|
(epoch, advance_mu, ptr) = handle
|
||||||
|
if self.epoch != epoch:
|
||||||
|
raise DMAError("Invalid handle")
|
||||||
|
dma_playback(now_mu(), ptr)
|
||||||
|
delay_mu(advance_mu)
|
||||||
|
|
|
@ -108,6 +108,7 @@ static mut API: &'static [(&'static str, *const ())] = &[
|
||||||
api!(dma_record_start = ::dma_record_start),
|
api!(dma_record_start = ::dma_record_start),
|
||||||
api!(dma_record_stop = ::dma_record_stop),
|
api!(dma_record_stop = ::dma_record_stop),
|
||||||
api!(dma_erase = ::dma_erase),
|
api!(dma_erase = ::dma_erase),
|
||||||
|
api!(dma_retrieve = ::dma_retrieve),
|
||||||
api!(dma_playback = ::dma_playback),
|
api!(dma_playback = ::dma_playback),
|
||||||
|
|
||||||
api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state),
|
api!(drtio_get_channel_state = ::rtio::drtio_dbg::get_channel_state),
|
||||||
|
|
|
@ -360,39 +360,43 @@ extern fn dma_erase(name: CSlice<u8>) {
|
||||||
send(&DmaEraseRequest { name: name });
|
send(&DmaEraseRequest { name: name });
|
||||||
}
|
}
|
||||||
|
|
||||||
extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
|
#[repr(C)]
|
||||||
|
struct DmaTrace {
|
||||||
|
duration: i64,
|
||||||
|
address: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
extern fn dma_retrieve(name: CSlice<u8>) -> DmaTrace {
|
||||||
let name = str::from_utf8(name.as_ref()).unwrap();
|
let name = str::from_utf8(name.as_ref()).unwrap();
|
||||||
|
|
||||||
send(&DmaPlaybackRequest { name: name });
|
send(&DmaPlaybackRequest { name: name });
|
||||||
let (succeeded, now_advance) =
|
recv!(&DmaPlaybackReply { trace, duration } => {
|
||||||
recv!(&DmaPlaybackReply { trace, duration } => unsafe {
|
|
||||||
match trace {
|
match trace {
|
||||||
Some(bytes) => {
|
Some(bytes) => Ok(DmaTrace {
|
||||||
let ptr = bytes.as_ptr() as usize;
|
address: bytes.as_ptr() as i32,
|
||||||
assert!(ptr % 64 == 0);
|
duration: duration as i64
|
||||||
|
}),
|
||||||
csr::rtio_dma::base_address_write(ptr as u64);
|
None => Err(())
|
||||||
csr::rtio_dma::time_offset_write(timestamp as u64);
|
|
||||||
|
|
||||||
csr::cri_con::selected_write(1);
|
|
||||||
csr::rtio_dma::enable_write(1);
|
|
||||||
while csr::rtio_dma::enable_read() != 0 {}
|
|
||||||
csr::cri_con::selected_write(0);
|
|
||||||
|
|
||||||
(true, duration)
|
|
||||||
}
|
|
||||||
None =>
|
|
||||||
(false, 0)
|
|
||||||
}
|
}
|
||||||
});
|
}).unwrap_or_else(|()| {
|
||||||
|
|
||||||
if !succeeded {
|
|
||||||
println!("DMA trace called {:?} not found", name);
|
println!("DMA trace called {:?} not found", name);
|
||||||
raise!("DMAError",
|
raise!("DMAError",
|
||||||
"DMA trace not found");
|
"DMA trace not found");
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
extern fn dma_playback(timestamp: i64, ptr: i32) {
|
||||||
|
assert!(ptr % 64 == 0);
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
|
csr::rtio_dma::base_address_write(ptr as u64);
|
||||||
|
csr::rtio_dma::time_offset_write(timestamp as u64);
|
||||||
|
|
||||||
|
csr::cri_con::selected_write(1);
|
||||||
|
csr::rtio_dma::enable_write(1);
|
||||||
|
while csr::rtio_dma::enable_read() != 0 {}
|
||||||
|
csr::cri_con::selected_write(0);
|
||||||
|
|
||||||
let status = csr::rtio_dma::error_status_read();
|
let status = csr::rtio_dma::error_status_read();
|
||||||
let timestamp = csr::rtio_dma::error_timestamp_read();
|
let timestamp = csr::rtio_dma::error_timestamp_read();
|
||||||
let channel = csr::rtio_dma::error_channel_read();
|
let channel = csr::rtio_dma::error_channel_read();
|
||||||
|
@ -408,8 +412,6 @@ extern fn dma_playback(timestamp: i64, name: CSlice<u8>) {
|
||||||
"RTIO sequence error at {0} mu, channel {1}",
|
"RTIO sequence error at {0} mu, channel {1}",
|
||||||
timestamp as i64, channel as i64, 0)
|
timestamp as i64, channel as i64, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
NOW += now_advance;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -501,19 +501,28 @@ class _DMA(EnvExperiment):
|
||||||
self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1))
|
self.set_dataset("dma_record_time", self.core.mu_to_seconds(t2 - t1))
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def replay(self):
|
def playback(self, use_handle=False):
|
||||||
self.core.break_realtime()
|
|
||||||
delay(100*ms)
|
|
||||||
self.core_dma.replay(self.trace_name)
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def replay_delta(self):
|
|
||||||
self.core.break_realtime()
|
self.core.break_realtime()
|
||||||
delay(100*ms)
|
delay(100*ms)
|
||||||
start = now_mu()
|
start = now_mu()
|
||||||
self.core_dma.replay(self.trace_name)
|
if use_handle:
|
||||||
|
handle = self.core_dma.get_handle(self.trace_name)
|
||||||
|
self.core_dma.playback_handle(handle)
|
||||||
|
else:
|
||||||
|
self.core_dma.playback(self.trace_name)
|
||||||
self.delta = now_mu() - start
|
self.delta = now_mu() - start
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def playback_many(self, n):
|
||||||
|
self.core.break_realtime()
|
||||||
|
delay(100*ms)
|
||||||
|
handle = self.core_dma.get_handle(self.trace_name)
|
||||||
|
t1 = self.core.get_rtio_counter_mu()
|
||||||
|
for i in range(n):
|
||||||
|
self.core_dma.playback_handle(handle)
|
||||||
|
t2 = self.core.get_rtio_counter_mu()
|
||||||
|
self.set_dataset("dma_playback_time", self.core.mu_to_seconds(t2 - t1))
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def erase(self):
|
def erase(self):
|
||||||
self.core_dma.erase(self.trace_name)
|
self.core_dma.erase(self.trace_name)
|
||||||
|
@ -524,16 +533,26 @@ class _DMA(EnvExperiment):
|
||||||
with self.core_dma.record(self.trace_name):
|
with self.core_dma.record(self.trace_name):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def invalidate(self, mode):
|
||||||
|
self.record()
|
||||||
|
handle = self.core_dma.get_handle(self.trace_name)
|
||||||
|
if mode == 0:
|
||||||
|
self.record()
|
||||||
|
elif mode == 1:
|
||||||
|
self.erase()
|
||||||
|
self.core_dma.playback_handle(handle)
|
||||||
|
|
||||||
|
|
||||||
class DMATest(ExperimentCase):
|
class DMATest(ExperimentCase):
|
||||||
def test_dma_storage(self):
|
def test_dma_storage(self):
|
||||||
exp = self.create(_DMA)
|
exp = self.create(_DMA)
|
||||||
exp.record()
|
exp.record()
|
||||||
exp.record() # overwrite
|
exp.record() # overwrite
|
||||||
exp.replay()
|
exp.playback()
|
||||||
exp.erase()
|
exp.erase()
|
||||||
with self.assertRaises(exceptions.DMAError):
|
with self.assertRaises(exceptions.DMAError):
|
||||||
exp.replay()
|
exp.playback()
|
||||||
|
|
||||||
def test_dma_nested(self):
|
def test_dma_nested(self):
|
||||||
exp = self.create(_DMA)
|
exp = self.create(_DMA)
|
||||||
|
@ -545,28 +564,32 @@ class DMATest(ExperimentCase):
|
||||||
|
|
||||||
exp = self.create(_DMA)
|
exp = self.create(_DMA)
|
||||||
exp.record()
|
exp.record()
|
||||||
get_analyzer_dump(core_host) # clear analyzer buffer
|
|
||||||
exp.replay()
|
|
||||||
|
|
||||||
dump = decode_dump(get_analyzer_dump(core_host))
|
for use_handle in [False, True]:
|
||||||
self.assertEqual(len(dump.messages), 3)
|
get_analyzer_dump(core_host) # clear analyzer buffer
|
||||||
self.assertIsInstance(dump.messages[-1], StoppedMessage)
|
exp.playback(use_handle)
|
||||||
self.assertIsInstance(dump.messages[0], OutputMessage)
|
|
||||||
self.assertEqual(dump.messages[0].channel, 1)
|
dump = decode_dump(get_analyzer_dump(core_host))
|
||||||
self.assertEqual(dump.messages[0].address, 0)
|
self.assertEqual(len(dump.messages), 3)
|
||||||
self.assertEqual(dump.messages[0].data, 1)
|
self.assertIsInstance(dump.messages[-1], StoppedMessage)
|
||||||
self.assertIsInstance(dump.messages[1], OutputMessage)
|
self.assertIsInstance(dump.messages[0], OutputMessage)
|
||||||
self.assertEqual(dump.messages[1].channel, 1)
|
self.assertEqual(dump.messages[0].channel, 1)
|
||||||
self.assertEqual(dump.messages[1].address, 0)
|
self.assertEqual(dump.messages[0].address, 0)
|
||||||
self.assertEqual(dump.messages[1].data, 0)
|
self.assertEqual(dump.messages[0].data, 1)
|
||||||
self.assertEqual(dump.messages[1].timestamp -
|
self.assertIsInstance(dump.messages[1], OutputMessage)
|
||||||
dump.messages[0].timestamp, 100)
|
self.assertEqual(dump.messages[1].channel, 1)
|
||||||
|
self.assertEqual(dump.messages[1].address, 0)
|
||||||
|
self.assertEqual(dump.messages[1].data, 0)
|
||||||
|
self.assertEqual(dump.messages[1].timestamp -
|
||||||
|
dump.messages[0].timestamp, 100)
|
||||||
|
|
||||||
def test_dma_delta(self):
|
def test_dma_delta(self):
|
||||||
exp = self.create(_DMA)
|
exp = self.create(_DMA)
|
||||||
exp.record()
|
exp.record()
|
||||||
exp.replay_delta()
|
|
||||||
self.assertEqual(exp.delta, 200)
|
for use_handle in [False, True]:
|
||||||
|
exp.playback(use_handle)
|
||||||
|
self.assertEqual(exp.delta, 200)
|
||||||
|
|
||||||
def test_dma_record_time(self):
|
def test_dma_record_time(self):
|
||||||
exp = self.create(_DMA)
|
exp = self.create(_DMA)
|
||||||
|
@ -574,4 +597,18 @@ class DMATest(ExperimentCase):
|
||||||
exp.record_many(count)
|
exp.record_many(count)
|
||||||
dt = self.dataset_mgr.get("dma_record_time")
|
dt = self.dataset_mgr.get("dma_record_time")
|
||||||
print("dt={}, dt/count={}".format(dt, dt/count))
|
print("dt={}, dt/count={}".format(dt, dt/count))
|
||||||
self.assertLess(dt/count, 15*us)
|
self.assertLess(dt/count, 16*us)
|
||||||
|
|
||||||
|
def test_dma_playback_time(self):
|
||||||
|
exp = self.create(_DMA)
|
||||||
|
count = 20000
|
||||||
|
exp.playback_many(count)
|
||||||
|
dt = self.dataset_mgr.get("dma_playback_time")
|
||||||
|
print("dt={}, dt/count={}".format(dt, dt/count))
|
||||||
|
self.assertLess(dt/count, 7*us)
|
||||||
|
|
||||||
|
def test_handle_invalidation(self):
|
||||||
|
exp = self.create(_DMA)
|
||||||
|
for mode in [0, 1]:
|
||||||
|
with self.assertRaises(exceptions.DMAError):
|
||||||
|
exp.invalidate(mode)
|
||||||
|
|
Loading…
Reference in New Issue