From cab9d90d0102c2b0cecdd8010565362da2d4f07a Mon Sep 17 00:00:00 2001 From: Michael Birtwell Date: Mon, 11 Apr 2022 16:41:38 +0100 Subject: [PATCH] Use sipyco.keepalive Remove the implementation of setting keepalive settings on sockets and use the implementation from sipyco instead. Signed-off-by: Michael Birtwell --- artiq/coredevice/comm.py | 28 ---------------------------- artiq/coredevice/comm_kernel.py | 5 ++--- artiq/coredevice/comm_mgmt.py | 5 ++--- artiq/coredevice/comm_moninj.py | 12 +++++++++--- 4 files changed, 13 insertions(+), 37 deletions(-) delete mode 100644 artiq/coredevice/comm.py diff --git a/artiq/coredevice/comm.py b/artiq/coredevice/comm.py deleted file mode 100644 index fb70a59b5..000000000 --- a/artiq/coredevice/comm.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -import socket -import logging - -logger = logging.getLogger(__name__) - - -def set_keepalive(sock, after_idle, interval, max_fails): - if sys.platform.startswith("linux"): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, after_idle) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, max_fails) - elif sys.platform.startswith("win") or sys.platform.startswith("cygwin"): - # setting max_fails is not supported, typically ends up being 5 or 10 - # depending on Windows version - sock.ioctl(socket.SIO_KEEPALIVE_VALS, - (1, after_idle * 1000, interval * 1000)) - else: - logger.warning("TCP keepalive not supported on platform '%s', ignored", - sys.platform) - - -def initialize_connection(host, port): - sock = socket.create_connection((host, port)) - set_keepalive(sock, 10, 10, 3) - logger.debug("connected to %s:%d", host, port) - return sock diff --git a/artiq/coredevice/comm_kernel.py b/artiq/coredevice/comm_kernel.py index 0b5dd84b7..2d529b6f9 100644 --- a/artiq/coredevice/comm_kernel.py +++ b/artiq/coredevice/comm_kernel.py @@ -8,9 +8,8 @@ from fractions import Fraction from collections import namedtuple from artiq.coredevice import exceptions -from artiq.coredevice.comm import initialize_connection from artiq import __version__ as software_version - +from sipyco.keepalive import create_connection logger = logging.getLogger(__name__) @@ -195,7 +194,7 @@ class CommKernel: def open(self): if hasattr(self, "socket"): return - self.socket = initialize_connection(self.host, self.port) + self.socket = create_connection(self.host, self.port) self.socket.sendall(b"ARTIQ coredev\n") endian = self._read(1) if endian == b"e": diff --git a/artiq/coredevice/comm_mgmt.py b/artiq/coredevice/comm_mgmt.py index 793b44a75..539643751 100644 --- a/artiq/coredevice/comm_mgmt.py +++ b/artiq/coredevice/comm_mgmt.py @@ -2,8 +2,7 @@ from enum import Enum import logging import struct -from artiq.coredevice.comm import initialize_connection - +from sipyco.keepalive import create_connection logger = logging.getLogger(__name__) @@ -54,7 +53,7 @@ class CommMgmt: def open(self): if hasattr(self, "socket"): return - self.socket = initialize_connection(self.host, self.port) + self.socket = create_connection(self.host, self.port) self.socket.sendall(b"ARTIQ management\n") endian = self._read(1) if endian == b"e": diff --git a/artiq/coredevice/comm_moninj.py b/artiq/coredevice/comm_moninj.py index a6b95983e..b5c2ee40d 100644 --- a/artiq/coredevice/comm_moninj.py +++ b/artiq/coredevice/comm_moninj.py @@ -2,7 +2,8 @@ import asyncio import logging import struct from enum import Enum -from .comm import set_keepalive + +from sipyco.keepalive import async_open_connection __all__ = ["TTLProbe", "TTLOverride", "CommMonInj"] @@ -28,8 +29,13 @@ class CommMonInj: self.disconnect_cb = disconnect_cb async def connect(self, host, port=1383): - self._reader, self._writer = await asyncio.open_connection(host, port) - set_keepalive(self._writer.transport.get_extra_info('socket'), 1, 1, 3) + self._reader, self._writer = await async_open_connection( + host, + port, + after_idle=1, + interval=1, + max_fails=3, + ) try: self._writer.write(b"ARTIQ moninj\n")