diff --git a/artiq/frontend/artiq_devtool.py b/artiq/frontend/artiq_devtool.py index a602cdd71..8324b8aa6 100755 --- a/artiq/frontend/artiq_devtool.py +++ b/artiq/frontend/artiq_devtool.py @@ -112,11 +112,31 @@ def main(): elif action == "connect": transport = client.get_transport() - def forwarder(port): + + def forwarder(local_stream, remote_stream): + try: + while True: + r, _, _ = select.select([local_stream, remote_stream], [], []) + if local_stream in r: + data = local_stream.recv(65535) + if data == b"": + break + remote_stream.sendall(data) + if remote_stream in r: + data = remote_stream.recv(65535) + if data == b"": + break + local_stream.sendall(data) + except Exception as err: + logger.error("Cannot forward on port %s: %s", port, repr(err)) + local_stream.close() + remote_stream.close() + + def listener(port): listener = socket.socket() listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listener.bind(('localhost', port)) - listener.listen(1) + listener.listen(8) while True: local_stream, peer_addr = listener.accept() logger.info("Accepting %s:%s and opening SSH channel to %s:%s", @@ -128,31 +148,17 @@ def main(): try: remote_stream = \ transport.open_channel('direct-tcpip', (args.device, port), peer_addr) - except Exception as e: + except Exception: logger.exception("Cannot open channel on port %s", port) continue - while True: - try: - r, _, _ = select.select([local_stream, remote_stream], [], []) - if local_stream in r: - data = local_stream.recv(65535) - if data == b"": - break - remote_stream.sendall(data) - if remote_stream in r: - data = remote_stream.recv(65535) - if data == b"": - break - local_stream.sendall(data) - except Exception as e: - logger.exception("Forward error on port %s", port) - break - local_stream.close() - remote_stream.close() + + thread = threading.Thread(target=forwarder, args=(local_stream, remote_stream), + name="forward-{}".format(port), daemon=True) + thread.start() for port in (1380, 1381, 1382, 1383): - thread = threading.Thread(target=forwarder, args=(port,), - name="port-{}".format(port), daemon=True) + thread = threading.Thread(target=listener, args=(port,), + name="listen-{}".format(port), daemon=True) thread.start() logger.info("Connecting to device")