1
0
forked from M-Labs/artiq

Upgrade smoltcp 0.6.0 -> 0.8.0

Main changes:
Deal with interfaces now being generic over mediums, update interface name
and initialisation.
Interfaces now own their sockets. So we store a reference to the Interface
instead of the SocketSet in Scheduler and IO.
Sockets are no longer reference counted. We never called the function to
increase the socket's reference count, so now we just remove it where it
was previously released. This will result in the socket being dropped at
a different time, but I think that should be fine.

Tested firmware upload to the bootloader and spamming artiq_coremgmt log
calls to download the log from the firmware.

Signed-off-by: Michael Birtwell <michael.birtwell@oxionics.com>
This commit is contained in:
Michael Birtwell 2022-01-14 16:51:43 +00:00 committed by Sebastien Bourdeauducq
parent 06ad76b6ab
commit c60de48a30
7 changed files with 105 additions and 96 deletions

View File

@ -240,6 +240,12 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c75de51135344a4f8ed3cfe2720dc27736f7711989703a0b43aadf3753c55577" checksum = "c75de51135344a4f8ed3cfe2720dc27736f7711989703a0b43aadf3753c55577"
[[package]]
name = "managed"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.4.1" version = "2.4.1"
@ -321,7 +327,7 @@ dependencies = [
"io", "io",
"log", "log",
"logger_artiq", "logger_artiq",
"managed", "managed 0.7.2",
"proto_artiq", "proto_artiq",
"riscv", "riscv",
"smoltcp", "smoltcp",
@ -365,13 +371,13 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
[[package]] [[package]]
name = "smoltcp" name = "smoltcp"
version = "0.6.0" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fe46639fd2ec79eadf8fe719f237a7a0bd4dac5d957f1ca5bbdbc1c3c39e53a" checksum = "d2308a1657c8db1f5b4993bab4e620bdbe5623bd81f254cf60326767bb243237"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"byteorder", "byteorder",
"managed", "managed 0.8.0",
] ]
[[package]] [[package]]

View File

@ -16,5 +16,5 @@ build_misoc = { path = "../libbuild_misoc" }
byteorder = { version = "1.0", default-features = false } byteorder = { version = "1.0", default-features = false }
crc = { version = "1.7", default-features = false } crc = { version = "1.7", default-features = false }
board_misoc = { path = "../libboard_misoc", features = ["uart_console", "smoltcp"] } board_misoc = { path = "../libboard_misoc", features = ["uart_console", "smoltcp"] }
smoltcp = { version = "0.6.0", default-features = false, features = ["ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp"] } smoltcp = { version = "0.8.0", default-features = false, features = ["medium-ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
riscv = { version = "0.6.0", features = ["inline-asm"] } riscv = { version = "0.6.0", features = ["inline-asm"] }

View File

@ -18,6 +18,8 @@ use board_misoc::slave_fpga;
use board_misoc::{clock, ethmac, net_settings}; use board_misoc::{clock, ethmac, net_settings};
use board_misoc::uart_console::Console; use board_misoc::uart_console::Console;
use riscv::register::{mcause, mepc, mtval}; use riscv::register::{mcause, mepc, mtval};
use smoltcp::iface::SocketStorage;
use smoltcp::wire::HardwareAddress;
fn check_integrity() -> bool { fn check_integrity() -> bool {
extern { extern {
@ -396,6 +398,9 @@ fn network_boot() {
println!("Initializing network..."); println!("Initializing network...");
// Assuming only one socket is ever needed by the bootloader.
// The smoltcp reuses the listening socket when the connection is established.
let mut sockets = [SocketStorage::EMPTY];
let mut net_device = unsafe { ethmac::EthernetDevice::new() }; let mut net_device = unsafe { ethmac::EthernetDevice::new() };
net_device.reset_phy_if_any(); net_device.reset_phy_if_any();
@ -412,15 +417,15 @@ fn network_boot() {
let mut interface = match net_addresses.ipv6_addr { let mut interface = match net_addresses.ipv6_addr {
Some(addr) => { Some(addr) => {
ip_addrs[2] = IpCidr::new(addr, 0); ip_addrs[2] = IpCidr::new(addr, 0);
smoltcp::iface::EthernetInterfaceBuilder::new(net_device) smoltcp::iface::InterfaceBuilder::new(net_device, &mut sockets[..])
.ethernet_addr(net_addresses.hardware_addr) .hardware_addr(HardwareAddress::Ethernet(net_addresses.hardware_addr))
.ip_addrs(&mut ip_addrs[..]) .ip_addrs(&mut ip_addrs[..])
.neighbor_cache(neighbor_cache) .neighbor_cache(neighbor_cache)
.finalize() .finalize()
} }
None => None =>
smoltcp::iface::EthernetInterfaceBuilder::new(net_device) smoltcp::iface::InterfaceBuilder::new(net_device, &mut sockets[..])
.ethernet_addr(net_addresses.hardware_addr) .hardware_addr(HardwareAddress::Ethernet(net_addresses.hardware_addr))
.ip_addrs(&mut ip_addrs[..2]) .ip_addrs(&mut ip_addrs[..2])
.neighbor_cache(neighbor_cache) .neighbor_cache(neighbor_cache)
.finalize() .finalize()
@ -429,14 +434,10 @@ fn network_boot() {
let mut rx_storage = [0; 4096]; let mut rx_storage = [0; 4096];
let mut tx_storage = [0; 128]; let mut tx_storage = [0; 128];
let mut socket_set_entries: [_; 1] = Default::default();
let mut sockets =
smoltcp::socket::SocketSet::new(&mut socket_set_entries[..]);
let tcp_rx_buffer = smoltcp::socket::TcpSocketBuffer::new(&mut rx_storage[..]); let tcp_rx_buffer = smoltcp::socket::TcpSocketBuffer::new(&mut rx_storage[..]);
let tcp_tx_buffer = smoltcp::socket::TcpSocketBuffer::new(&mut tx_storage[..]); let tcp_tx_buffer = smoltcp::socket::TcpSocketBuffer::new(&mut tx_storage[..]);
let tcp_socket = smoltcp::socket::TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); let tcp_socket = smoltcp::socket::TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
let tcp_handle = sockets.add(tcp_socket); let tcp_handle = interface.add_socket(tcp_socket);
let mut net_conn = NetConn::new(); let mut net_conn = NetConn::new();
let mut boot_time = None; let mut boot_time = None;
@ -446,7 +447,7 @@ fn network_boot() {
loop { loop {
let timestamp = clock::get_ms() as i64; let timestamp = clock::get_ms() as i64;
{ {
let socket = &mut *sockets.get::<smoltcp::socket::TcpSocket>(tcp_handle); let socket = &mut *interface.get_socket::<smoltcp::socket::TcpSocket>(tcp_handle);
match boot_time { match boot_time {
None => { None => {
@ -475,7 +476,7 @@ fn network_boot() {
} }
} }
match interface.poll(&mut sockets, smoltcp::time::Instant::from_millis(timestamp)) { match interface.poll(smoltcp::time::Instant::from_millis(timestamp)) {
Ok(_) => (), Ok(_) => (),
Err(smoltcp::Error::Unrecognized) => (), Err(smoltcp::Error::Unrecognized) => (),
Err(err) => println!("Network error: {}", err) Err(err) => println!("Network error: {}", err)

View File

@ -15,7 +15,7 @@ build_misoc = { path = "../libbuild_misoc" }
[dependencies] [dependencies]
byteorder = { version = "1.0", default-features = false } byteorder = { version = "1.0", default-features = false }
log = { version = "0.4", default-features = false, optional = true } log = { version = "0.4", default-features = false, optional = true }
smoltcp = { version = "0.6.0", default-features = false, optional = true } smoltcp = { version = "0.8.0", default-features = false, optional = true }
riscv = { version = "0.6.0", features = ["inline-asm"] } riscv = { version = "0.6.0", features = ["inline-asm"] }
[features] [features]

View File

@ -27,7 +27,7 @@ board_misoc = { path = "../libboard_misoc", features = ["uart_console", "smoltcp
logger_artiq = { path = "../liblogger_artiq" } logger_artiq = { path = "../liblogger_artiq" }
board_artiq = { path = "../libboard_artiq" } board_artiq = { path = "../libboard_artiq" }
proto_artiq = { path = "../libproto_artiq", features = ["log", "alloc"] } proto_artiq = { path = "../libproto_artiq", features = ["log", "alloc"] }
smoltcp = { version = "0.6.0", default-features = false, features = ["alloc", "ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp"] } smoltcp = { version = "0.8.0", default-features = false, features = ["alloc", "medium-ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
riscv = { version = "0.6.0", features = ["inline-asm"] } riscv = { version = "0.6.0", features = ["inline-asm"] }
[dependencies.fringe] [dependencies.fringe]

View File

@ -27,7 +27,7 @@ extern crate riscv;
use core::cell::RefCell; use core::cell::RefCell;
use core::convert::TryFrom; use core::convert::TryFrom;
use smoltcp::wire::IpCidr; use smoltcp::wire::{IpCidr, HardwareAddress};
use board_misoc::{csr, ident, clock, spiflash, config, net_settings, pmp, boot}; use board_misoc::{csr, ident, clock, spiflash, config, net_settings, pmp, boot};
#[cfg(has_ethmac)] #[cfg(has_ethmac)]
@ -123,38 +123,33 @@ fn startup() {
net_device.reset_phy_if_any(); net_device.reset_phy_if_any();
let net_device = { let net_device = {
use smoltcp::time::Instant; use smoltcp::phy::Tracer;
use smoltcp::wire::PrettyPrinter;
use smoltcp::wire::EthernetFrame;
fn net_trace_writer(timestamp: Instant, printer: PrettyPrinter<EthernetFrame<&[u8]>>) { // We can't create the function pointer as a separate variable here because the type of
print!("\x1b[37m[{:6}.{:03}s]\n{}\x1b[0m\n", // the packet argument Packet isn't accessible and rust's type inference isn't sufficient
timestamp.secs(), timestamp.millis(), printer) // to propagate in to a local var.
}
fn net_trace_silent(_timestamp: Instant, _printer: PrettyPrinter<EthernetFrame<&[u8]>>) {}
let net_trace_fn: fn(Instant, PrettyPrinter<EthernetFrame<&[u8]>>);
match config::read_str("net_trace", |r| r.map(|s| s == "1")) { match config::read_str("net_trace", |r| r.map(|s| s == "1")) {
Ok(true) => net_trace_fn = net_trace_writer, Ok(true) => Tracer::new(net_device, |timestamp, packet| {
_ => net_trace_fn = net_trace_silent print!("\x1b[37m[{:6}.{:03}s]\n{}\x1b[0m\n",
timestamp.secs(), timestamp.millis(), packet)
}),
_ => Tracer::new(net_device, |_, _| {}),
} }
smoltcp::phy::EthernetTracer::new(net_device, net_trace_fn)
}; };
let neighbor_cache = let neighbor_cache =
smoltcp::iface::NeighborCache::new(alloc::collections::btree_map::BTreeMap::new()); smoltcp::iface::NeighborCache::new(alloc::collections::btree_map::BTreeMap::new());
let net_addresses = net_settings::get_adresses(); let net_addresses = net_settings::get_adresses();
info!("network addresses: {}", net_addresses); info!("network addresses: {}", net_addresses);
let mut interface = match net_addresses.ipv6_addr { let interface = match net_addresses.ipv6_addr {
Some(addr) => { Some(addr) => {
let ip_addrs = [ let ip_addrs = [
IpCidr::new(net_addresses.ipv4_addr, 0), IpCidr::new(net_addresses.ipv4_addr, 0),
IpCidr::new(net_addresses.ipv6_ll_addr, 0), IpCidr::new(net_addresses.ipv6_ll_addr, 0),
IpCidr::new(addr, 0) IpCidr::new(addr, 0)
]; ];
smoltcp::iface::EthernetInterfaceBuilder::new(net_device) smoltcp::iface::InterfaceBuilder::new(net_device, vec![])
.ethernet_addr(net_addresses.hardware_addr) .hardware_addr(HardwareAddress::Ethernet(net_addresses.hardware_addr))
.ip_addrs(ip_addrs) .ip_addrs(ip_addrs)
.neighbor_cache(neighbor_cache) .neighbor_cache(neighbor_cache)
.finalize() .finalize()
@ -164,8 +159,8 @@ fn startup() {
IpCidr::new(net_addresses.ipv4_addr, 0), IpCidr::new(net_addresses.ipv4_addr, 0),
IpCidr::new(net_addresses.ipv6_ll_addr, 0) IpCidr::new(net_addresses.ipv6_ll_addr, 0)
]; ];
smoltcp::iface::EthernetInterfaceBuilder::new(net_device) smoltcp::iface::InterfaceBuilder::new(net_device, vec![])
.ethernet_addr(net_addresses.hardware_addr) .hardware_addr(HardwareAddress::Ethernet(net_addresses.hardware_addr))
.ip_addrs(ip_addrs) .ip_addrs(ip_addrs)
.neighbor_cache(neighbor_cache) .neighbor_cache(neighbor_cache)
.finalize() .finalize()
@ -184,7 +179,7 @@ fn startup() {
drtio_routing::interconnect_disable_all(); drtio_routing::interconnect_disable_all();
let aux_mutex = sched::Mutex::new(); let aux_mutex = sched::Mutex::new();
let mut scheduler = sched::Scheduler::new(); let mut scheduler = sched::Scheduler::new(interface);
let io = scheduler.io(); let io = scheduler.io();
rtio_mgt::startup(&io, &aux_mutex, &drtio_routing_table, &up_destinations); rtio_mgt::startup(&io, &aux_mutex, &drtio_routing_table, &up_destinations);
@ -211,19 +206,7 @@ fn startup() {
let mut net_stats = ethmac::EthernetStatistics::new(); let mut net_stats = ethmac::EthernetStatistics::new();
loop { loop {
scheduler.run(); scheduler.run();
scheduler.run_network();
{
let sockets = &mut *scheduler.sockets().borrow_mut();
loop {
let timestamp = smoltcp::time::Instant::from_millis(clock::get_ms() as i64);
match interface.poll(sockets, timestamp) {
Ok(true) => (),
Ok(false) => break,
Err(smoltcp::Error::Unrecognized) => (),
Err(err) => debug!("network error: {}", err)
}
}
}
if let Some(_net_stats_diff) = net_stats.update() { if let Some(_net_stats_diff) = net_stats.update() {
debug!("ethernet mac:{}", ethmac::EthernetStatistics::new()); debug!("ethernet mac:{}", ethmac::EthernetStatistics::new());

View File

@ -9,11 +9,13 @@ use fringe::generator::{Generator, Yielder, State as GeneratorState};
use smoltcp::time::Duration; use smoltcp::time::Duration;
use smoltcp::Error as NetworkError; use smoltcp::Error as NetworkError;
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use smoltcp::socket::{SocketHandle, SocketRef}; use smoltcp::iface::{Interface, SocketHandle};
use io::{Read, Write}; use io::{Read, Write};
use board_misoc::clock; use board_misoc::clock;
use urc::Urc; use urc::Urc;
use board_misoc::ethmac::EthernetDevice;
use smoltcp::phy::Tracer;
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum Error { pub enum Error {
@ -31,8 +33,6 @@ impl From<NetworkError> for Error {
} }
} }
type SocketSet = ::smoltcp::socket::SocketSet<'static, 'static, 'static>;
#[derive(Debug)] #[derive(Debug)]
struct WaitRequest { struct WaitRequest {
event: Option<*mut dyn FnMut() -> bool>, event: Option<*mut dyn FnMut() -> bool>,
@ -59,7 +59,7 @@ impl Thread {
unsafe fn new<F>(io: &Io, stack_size: usize, f: F) -> ThreadHandle unsafe fn new<F>(io: &Io, stack_size: usize, f: F) -> ThreadHandle
where F: 'static + FnOnce(Io) + Send { where F: 'static + FnOnce(Io) + Send {
let spawned = io.spawned.clone(); let spawned = io.spawned.clone();
let sockets = io.sockets.clone(); let network = io.network.clone();
// Add a 4k stack guard to the stack of any new threads // Add a 4k stack guard to the stack of any new threads
let stack = OwnedStack::new(stack_size + 4096); let stack = OwnedStack::new(stack_size + 4096);
@ -67,8 +67,8 @@ impl Thread {
generator: Generator::unsafe_new(stack, |yielder, _| { generator: Generator::unsafe_new(stack, |yielder, _| {
f(Io { f(Io {
yielder: Some(yielder), yielder: Some(yielder),
spawned: spawned, spawned,
sockets: sockets network
}) })
}), }),
waiting_for: WaitRequest { waiting_for: WaitRequest {
@ -115,19 +115,21 @@ impl ThreadHandle {
} }
} }
type Network = Interface<'static, Tracer<EthernetDevice>>;
pub struct Scheduler { pub struct Scheduler {
threads: Vec<ThreadHandle>, threads: Vec<ThreadHandle>,
spawned: Urc<RefCell<Vec<ThreadHandle>>>, spawned: Urc<RefCell<Vec<ThreadHandle>>>,
sockets: Urc<RefCell<SocketSet>>, network: Urc<RefCell<Network>>,
run_idx: usize, run_idx: usize,
} }
impl Scheduler { impl Scheduler {
pub fn new() -> Scheduler { pub fn new(network: Network) -> Scheduler {
Scheduler { Scheduler {
threads: Vec::new(), threads: Vec::new(),
spawned: Urc::new(RefCell::new(Vec::new())), spawned: Urc::new(RefCell::new(Vec::new())),
sockets: Urc::new(RefCell::new(SocketSet::new(Vec::new()))), network: Urc::new(RefCell::new(network)),
run_idx: 0, run_idx: 0,
} }
} }
@ -136,13 +138,11 @@ impl Scheduler {
Io { Io {
yielder: None, yielder: None,
spawned: self.spawned.clone(), spawned: self.spawned.clone(),
sockets: self.sockets.clone() network: self.network.clone()
} }
} }
pub fn run(&mut self) { pub fn run(&mut self) {
self.sockets.borrow_mut().prune();
self.threads.append(&mut *self.spawned.borrow_mut()); self.threads.append(&mut *self.spawned.borrow_mut());
if self.threads.len() == 0 { return } if self.threads.len() == 0 { return }
@ -188,8 +188,17 @@ impl Scheduler {
} }
} }
pub fn sockets(&self) -> &RefCell<SocketSet> { pub fn run_network(&mut self) {
&*self.sockets let mut interface = self.network.borrow_mut();
loop {
let timestamp = smoltcp::time::Instant::from_millis(clock::get_ms() as i64);
match interface.poll(timestamp) {
Ok(true) => (),
Ok(false) => break,
Err(smoltcp::Error::Unrecognized) => (),
Err(err) => debug!("network error: {}", err)
}
}
} }
} }
@ -197,7 +206,7 @@ impl Scheduler {
pub struct Io<'a> { pub struct Io<'a> {
yielder: Option<&'a Yielder<WaitResult, WaitRequest>>, yielder: Option<&'a Yielder<WaitResult, WaitRequest>>,
spawned: Urc<RefCell<Vec<ThreadHandle>>>, spawned: Urc<RefCell<Vec<ThreadHandle>>>,
sockets: Urc<RefCell<SocketSet>>, network: Urc<RefCell<Network>>,
} }
impl<'a> Io<'a> { impl<'a> Io<'a> {
@ -291,10 +300,10 @@ impl<'a> Drop for MutexGuard<'a> {
macro_rules! until { macro_rules! until {
($socket:expr, $ty:ty, |$var:ident| $cond:expr) => ({ ($socket:expr, $ty:ty, |$var:ident| $cond:expr) => ({
let (sockets, handle) = ($socket.io.sockets.clone(), $socket.handle); let (network, handle) = ($socket.io.network.clone(), $socket.handle);
$socket.io.until(move || { $socket.io.until(move || {
let mut sockets = sockets.borrow_mut(); let mut network = network.borrow_mut();
let $var = sockets.get::<$ty>(handle); let $var = network.get_socket::<$ty>(handle);
$cond $cond
}) })
}) })
@ -316,9 +325,9 @@ impl<'a> TcpListener<'a> {
fn new_lower(io: &'a Io<'a>, buffer_size: usize) -> SocketHandle { fn new_lower(io: &'a Io<'a>, buffer_size: usize) -> SocketHandle {
let rx_buffer = vec![0; buffer_size]; let rx_buffer = vec![0; buffer_size];
let tx_buffer = vec![0; buffer_size]; let tx_buffer = vec![0; buffer_size];
io.sockets io.network
.borrow_mut() .borrow_mut()
.add(TcpSocketLower::new( .add_socket(TcpSocketLower::new(
TcpSocketBuffer::new(rx_buffer), TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer))) TcpSocketBuffer::new(tx_buffer)))
} }
@ -333,9 +342,9 @@ impl<'a> TcpListener<'a> {
} }
fn with_lower<F, R>(&self, f: F) -> R fn with_lower<F, R>(&self, f: F) -> R
where F: FnOnce(SocketRef<TcpSocketLower>) -> R { where F: FnOnce(&mut TcpSocketLower) -> R {
let mut sockets = self.io.sockets.borrow_mut(); let mut network = self.io.network.borrow_mut();
let result = f(sockets.get(self.handle.get())); let result = f(network.get_socket(self.handle.get()));
result result
} }
@ -353,7 +362,7 @@ impl<'a> TcpListener<'a> {
pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<(), Error> { pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<(), Error> {
let endpoint = endpoint.into(); let endpoint = endpoint.into();
self.with_lower(|mut s| s.listen(endpoint)) self.with_lower(|s| s.listen(endpoint))
.map(|()| { .map(|()| {
self.endpoint.set(endpoint); self.endpoint.set(endpoint);
() ()
@ -365,10 +374,10 @@ impl<'a> TcpListener<'a> {
// We're waiting until at least one half of the connection becomes open. // We're waiting until at least one half of the connection becomes open.
// This handles the case where a remote socket immediately sends a FIN-- // This handles the case where a remote socket immediately sends a FIN--
// that still counts as accepting even though nothing may be sent. // that still counts as accepting even though nothing may be sent.
let (sockets, handle) = (self.io.sockets.clone(), self.handle.get()); let (network, handle) = (self.io.network.clone(), self.handle.get());
self.io.until(move || { self.io.until(move || {
let mut sockets = sockets.borrow_mut(); let mut network = network.borrow_mut();
let socket = sockets.get::<TcpSocketLower>(handle); let socket = network.get_socket::<TcpSocketLower>(handle);
socket.may_send() || socket.may_recv() socket.may_send() || socket.may_recv()
})?; })?;
@ -385,14 +394,14 @@ impl<'a> TcpListener<'a> {
} }
pub fn close(&self) { pub fn close(&self) {
self.with_lower(|mut s| s.close()) self.with_lower(|s| s.close())
} }
} }
impl<'a> Drop for TcpListener<'a> { impl<'a> Drop for TcpListener<'a> {
fn drop(&mut self) { fn drop(&mut self) {
self.with_lower(|mut s| s.close()); self.with_lower(|s| s.close());
self.io.sockets.borrow_mut().release(self.handle.get()) self.io.network.borrow_mut().remove_socket(self.handle.get());
} }
} }
@ -416,9 +425,9 @@ impl<'a> TcpStream<'a> {
} }
fn with_lower<F, R>(&self, f: F) -> R fn with_lower<F, R>(&self, f: F) -> R
where F: FnOnce(SocketRef<TcpSocketLower>) -> R { where F: FnOnce(&mut TcpSocketLower) -> R {
let mut sockets = self.io.sockets.borrow_mut(); let mut network = self.io.network.borrow_mut();
let result = f(sockets.get(self.handle)); let result = f(network.get_socket(self.handle));
result result
} }
@ -455,7 +464,7 @@ impl<'a> TcpStream<'a> {
} }
pub fn set_timeout(&self, value: Option<u64>) { pub fn set_timeout(&self, value: Option<u64>) {
self.with_lower(|mut s| s.set_timeout(value.map(Duration::from_millis))) self.with_lower(|s| s.set_timeout(value.map(Duration::from_millis)))
} }
pub fn keep_alive(&self) -> Option<u64> { pub fn keep_alive(&self) -> Option<u64> {
@ -463,11 +472,11 @@ impl<'a> TcpStream<'a> {
} }
pub fn set_keep_alive(&self, value: Option<u64>) { pub fn set_keep_alive(&self, value: Option<u64>) {
self.with_lower(|mut s| s.set_keep_alive(value.map(Duration::from_millis))) self.with_lower(|s| s.set_keep_alive(value.map(Duration::from_millis)))
} }
pub fn close(&self) -> Result<(), Error> { pub fn close(&self) -> Result<(), Error> {
self.with_lower(|mut s| s.close()); self.with_lower(|s| s.close());
until!(self, TcpSocketLower, |s| !s.is_open())?; until!(self, TcpSocketLower, |s| !s.is_open())?;
// right now the socket may be in TIME-WAIT state. if we don't give it a chance to send // right now the socket may be in TIME-WAIT state. if we don't give it a chance to send
// a packet, and the user code executes a loop { s.listen(); s.read(); s.close(); } // a packet, and the user code executes a loop { s.listen(); s.read(); s.close(); }
@ -481,23 +490,33 @@ impl<'a> Read for TcpStream<'a> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::ReadError> { fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::ReadError> {
// Only borrow the underlying socket for the span of the next statement. // Only borrow the underlying socket for the span of the next statement.
let result = self.with_lower(|mut s| s.recv_slice(buf)); let result = self.with_lower(|s| s.recv_slice(buf));
match result { match result {
// Slow path: we need to block until buffer is non-empty. // Slow path: we need to block until buffer is non-empty.
Ok(0) => { Ok(0) => {
until!(self, TcpSocketLower, |s| s.can_recv() || !s.may_recv())?; until!(self, TcpSocketLower, |s| s.can_recv() || !s.may_recv())?;
match self.with_lower(|mut s| s.recv_slice(buf)) { match self.with_lower(|s| s.recv_slice(buf)) {
Ok(length) => Ok(length), Ok(length) => Ok(length),
Err(NetworkError::Finished) |
Err(NetworkError::Illegal) => Ok(0), Err(NetworkError::Illegal) => Ok(0),
_ => unreachable!() Err(e) => {
panic!("Unexpected error from smoltcp: {}", e);
}
} }
} }
// Fast path: we had data in buffer. // Fast path: we had data in buffer.
Ok(length) => Ok(length), Ok(length) => Ok(length),
// We've received a fin.
Err(NetworkError::Finished) |
// Error path: the receive half of the socket is not open. // Error path: the receive half of the socket is not open.
Err(NetworkError::Illegal) => Ok(0), Err(NetworkError::Illegal) => Ok(0),
// No other error may be returned. // No other error may be returned.
Err(_) => unreachable!() Err(e) => {
// This could return Err(Error::Network(e)) rather than panic,
// but I expect that'll just cause a panic later perhaps with
// less interesting context.
panic!("Unexpected error from smoltcp: {}", e);
}
} }
} }
} }
@ -508,12 +527,12 @@ impl<'a> Write for TcpStream<'a> {
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::WriteError> { fn write(&mut self, buf: &[u8]) -> Result<usize, Self::WriteError> {
// Only borrow the underlying socket for the span of the next statement. // Only borrow the underlying socket for the span of the next statement.
let result = self.with_lower(|mut s| s.send_slice(buf)); let result = self.with_lower(|s| s.send_slice(buf));
match result { match result {
// Slow path: we need to block until buffer is non-full. // Slow path: we need to block until buffer is non-full.
Ok(0) => { Ok(0) => {
until!(self, TcpSocketLower, |s| s.can_send() || !s.may_send())?; until!(self, TcpSocketLower, |s| s.can_send() || !s.may_send())?;
match self.with_lower(|mut s| s.send_slice(buf)) { match self.with_lower(|s| s.send_slice(buf)) {
Ok(length) => Ok(length), Ok(length) => Ok(length),
Err(NetworkError::Illegal) => Ok(0), Err(NetworkError::Illegal) => Ok(0),
_ => unreachable!() _ => unreachable!()
@ -540,7 +559,7 @@ impl<'a> Write for TcpStream<'a> {
impl<'a> Drop for TcpStream<'a> { impl<'a> Drop for TcpStream<'a> {
fn drop(&mut self) { fn drop(&mut self) {
self.with_lower(|mut s| s.close()); self.with_lower(|s| s.close());
self.io.sockets.borrow_mut().release(self.handle) self.io.network.borrow_mut().remove_socket(self.handle);
} }
} }