From 688f3d9225cdebe6ea6707c076c4712b4a1eb51b Mon Sep 17 00:00:00 2001 From: Florian Agbuya Date: Wed, 3 Jul 2024 09:12:36 +0200 Subject: [PATCH] client: integrate asyncio with happy eyeballs support --- artiq/frontend/afws_client.py | 119 ++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 55 deletions(-) diff --git a/artiq/frontend/afws_client.py b/artiq/frontend/afws_client.py index 2faaf560b..6351f761d 100755 --- a/artiq/frontend/afws_client.py +++ b/artiq/frontend/afws_client.py @@ -4,6 +4,7 @@ import sys import argparse import os import socket +import asyncio import ssl import io import zipfile @@ -41,22 +42,27 @@ def zip_unarchive(data, directory): class Client: def __init__(self, server, port, cafile): + self.server = server + self.port = port self.ssl_context = ssl.create_default_context(cafile=cafile) - self.raw_socket = socket.create_connection((server, port)) - self.init_websocket(server) - try: - self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server) - except: - self.raw_socket.close() - raise - self.fsocket = self.socket.makefile("rwb") + self.reader = None + self.writer = None - def init_websocket(self, server): - self.raw_socket.sendall("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n" - .format(server).encode()) + async def connect(self): + self.reader, self.writer = await asyncio.open_connection( + host=self.server, + port=self.port, + happy_eyeballs_delay=0.25 + ) + await self.init_websocket() + await self.writer.start_tls(self.ssl_context) + + async def init_websocket(self): + self.writer.write("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n" + .format(self.server).encode()) crlf_count = 0 while crlf_count < 4: - char = self.raw_socket.recv(1) + char = await self.reader.read(1) if not char: return ValueError("Connection closed during WebSocket initialization") if char == b"\r" or char == b"\n": @@ -64,30 +70,30 @@ class Client: else: crlf_count = 0 - def close(self): - self.socket.close() - self.raw_socket.close() + async def close(self): + if self.writer: + await self.writer.close() + await self.writer.wait_closed() - def send_command(self, *command): - self.fsocket.write((" ".join(command) + "\n").encode()) - self.fsocket.flush() + async def send_command(self, *command): + await self.writer.write((" ".join(command) + "\n").encode()) - def read_line(self): - return self.fsocket.readline().decode("ascii") + async def read_line(self): + return (await self.reader.readline()).decode("ascii") - def read_reply(self): - return self.fsocket.readline().decode("ascii").split() + async def read_reply(self): + return (await self.reader.readline()).decode("ascii").split() - def read_json(self): - return json.loads(self.fsocket.readline().decode("ascii")) + async def read_json(self): + return json.loads(await self.reader.readline().decode("ascii")) - def login(self, username, password): - self.send_command("LOGIN", username, password) - return self.read_reply() == ["HELLO"] + async def login(self, username, password): + await self.send_command("LOGIN", username, password) + return await self.read_reply() == ["HELLO"] - def build(self, major_ver, rev, variant, log, experimental_features): + async def build(self, major_ver, rev, variant, log, experimental_features): if not variant: - variant = self.get_single_variant(error_msg="User can build more than 1 variant - need to specify") + variant = await self.get_single_variant(error_msg="User can build more than 1 variant - need to specify") print("Building variant: {}".format(variant)) build_args = ( rev, @@ -96,19 +102,19 @@ class Client: major_ver, *experimental_features, ) - self.send_command("BUILD", *build_args) - reply = self.read_reply()[0] + await self.send_command("BUILD", *build_args) + reply = (await self.read_reply())[0] if reply != "BUILDING": return reply, None print("Build in progress. This may take 10-15 minutes.") if log: - line = self.read_line() + line = await self.read_line() while line != "" and line.startswith("LOG"): print(line[4:], end="") - line = self.read_line() + line = await self.read_line() reply, status = line.split() else: - reply, status = self.read_reply() + reply, status = await self.read_reply() if reply != "DONE": raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply)) if status != "done": @@ -129,19 +135,19 @@ class Client: print("Download completed.") return "OK", contents - def passwd(self, password): - self.send_command("PASSWD", password) - return self.read_reply() == ["OK"] + async def passwd(self, password): + await self.send_command("PASSWD", password) + return (await self.read_reply()) == ["OK"] - def get_variants(self): - self.send_command("GET_VARIANTS") - reply = self.read_reply()[0] + async def get_variants(self): + await self.send_command("GET_VARIANTS") + reply = (await self.read_reply())[0] if reply != "OK": raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply)) - return self.read_json() + return await self.read_json() - def get_single_variant(self, error_msg): - variants = self.get_variants() + async def get_single_variant(self, error_msg): + variants = await self.get_variants() if len(variants) != 1: print(error_msg) table = PrettyTable() @@ -152,17 +158,17 @@ class Client: sys.exit(1) return variants[0][0] - def get_json(self, variant): - self.send_command("GET_JSON", variant) - reply = self.read_reply() + async def get_json(self, variant): + await self.send_command("GET_JSON", variant) + reply = await self.read_reply() if reply[0] != "OK": return reply[0], None length = int(reply[1]) - json_str = self.fsocket.read(length).decode("ascii") + json_str = (await self.reader.read(length)).decode("ascii") return "OK", json_str -def main(): +async def main_async(): parser = argparse.ArgumentParser() parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)") parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)") @@ -186,6 +192,7 @@ def main(): args = parser.parse_args() client = Client(args.server, args.port, args.cert) + await client.connect() try: if args.action == "build": # do this before user enters password so errors are reported without unnecessary user action @@ -216,7 +223,7 @@ def main(): password = getpass("Current password: ") else: password = getpass() - if not client.login(args.username, password): + if not await client.login(args.username, password): print("Login failed") sys.exit(1) @@ -229,12 +236,12 @@ def main(): print("Passwords do not match") password = getpass("New password: ") password_confirm = getpass("New password (again): ") - if not client.passwd(password): + if not await client.passwd(password): print("Failed to change password") sys.exit(1) elif args.action == "build": # build dir and version variables set up above - result, contents = client.build(major_ver, rev, args.variant, args.log, args.experimental) + result, contents = await client.build(major_ver, rev, args.variant, args.log, args.experimental) if result != "OK": if result == "UNAUTHORIZED": print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.") @@ -245,7 +252,7 @@ def main(): sys.exit(1) zip_unarchive(contents, args.directory) elif args.action == "get_variants": - variants = client.get_variants() + variants = await client.get_variants() table = PrettyTable() table.field_names = ["Variant", "Expiry date"] for variant in variants: @@ -255,8 +262,8 @@ def main(): if args.variant: variant = args.variant else: - variant = client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify") - result, json_str = client.get_json(variant) + variant = await client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify") + result, json_str = await client.get_json(variant) if result != "OK": if result == "UNAUTHORIZED": print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.") @@ -272,8 +279,10 @@ def main(): else: raise ValueError finally: - client.close() + await client.close() +def main(): + asyncio.run(main_async()) if __name__ == "__main__": main()