2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-14 21:08:55 +08:00

client: integrate asyncio with happy eyeballs support

This commit is contained in:
Florian Agbuya 2024-07-03 09:12:36 +02:00 committed by Sebastien Bourdeauducq
parent 6f3322ea35
commit 688f3d9225

View File

@ -4,6 +4,7 @@ import sys
import argparse import argparse
import os import os
import socket import socket
import asyncio
import ssl import ssl
import io import io
import zipfile import zipfile
@ -41,22 +42,27 @@ def zip_unarchive(data, directory):
class Client: class Client:
def __init__(self, server, port, cafile): def __init__(self, server, port, cafile):
self.server = server
self.port = port
self.ssl_context = ssl.create_default_context(cafile=cafile) self.ssl_context = ssl.create_default_context(cafile=cafile)
self.raw_socket = socket.create_connection((server, port)) self.reader = None
self.init_websocket(server) self.writer = None
try:
self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server)
except:
self.raw_socket.close()
raise
self.fsocket = self.socket.makefile("rwb")
def init_websocket(self, server): async def connect(self):
self.raw_socket.sendall("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n" self.reader, self.writer = await asyncio.open_connection(
.format(server).encode()) host=self.server,
port=self.port,
happy_eyeballs_delay=0.25
)
await self.init_websocket()
await self.writer.start_tls(self.ssl_context)
async def init_websocket(self):
self.writer.write("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"
.format(self.server).encode())
crlf_count = 0 crlf_count = 0
while crlf_count < 4: while crlf_count < 4:
char = self.raw_socket.recv(1) char = await self.reader.read(1)
if not char: if not char:
return ValueError("Connection closed during WebSocket initialization") return ValueError("Connection closed during WebSocket initialization")
if char == b"\r" or char == b"\n": if char == b"\r" or char == b"\n":
@ -64,30 +70,30 @@ class Client:
else: else:
crlf_count = 0 crlf_count = 0
def close(self): async def close(self):
self.socket.close() if self.writer:
self.raw_socket.close() await self.writer.close()
await self.writer.wait_closed()
def send_command(self, *command): async def send_command(self, *command):
self.fsocket.write((" ".join(command) + "\n").encode()) await self.writer.write((" ".join(command) + "\n").encode())
self.fsocket.flush()
def read_line(self): async def read_line(self):
return self.fsocket.readline().decode("ascii") return (await self.reader.readline()).decode("ascii")
def read_reply(self): async def read_reply(self):
return self.fsocket.readline().decode("ascii").split() return (await self.reader.readline()).decode("ascii").split()
def read_json(self): async def read_json(self):
return json.loads(self.fsocket.readline().decode("ascii")) return json.loads(await self.reader.readline().decode("ascii"))
def login(self, username, password): async def login(self, username, password):
self.send_command("LOGIN", username, password) await self.send_command("LOGIN", username, password)
return self.read_reply() == ["HELLO"] return await self.read_reply() == ["HELLO"]
def build(self, major_ver, rev, variant, log, experimental_features): async def build(self, major_ver, rev, variant, log, experimental_features):
if not variant: if not variant:
variant = self.get_single_variant(error_msg="User can build more than 1 variant - need to specify") variant = await self.get_single_variant(error_msg="User can build more than 1 variant - need to specify")
print("Building variant: {}".format(variant)) print("Building variant: {}".format(variant))
build_args = ( build_args = (
rev, rev,
@ -96,19 +102,19 @@ class Client:
major_ver, major_ver,
*experimental_features, *experimental_features,
) )
self.send_command("BUILD", *build_args) await self.send_command("BUILD", *build_args)
reply = self.read_reply()[0] reply = (await self.read_reply())[0]
if reply != "BUILDING": if reply != "BUILDING":
return reply, None return reply, None
print("Build in progress. This may take 10-15 minutes.") print("Build in progress. This may take 10-15 minutes.")
if log: if log:
line = self.read_line() line = await self.read_line()
while line != "" and line.startswith("LOG"): while line != "" and line.startswith("LOG"):
print(line[4:], end="") print(line[4:], end="")
line = self.read_line() line = await self.read_line()
reply, status = line.split() reply, status = line.split()
else: else:
reply, status = self.read_reply() reply, status = await self.read_reply()
if reply != "DONE": if reply != "DONE":
raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply)) raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply))
if status != "done": if status != "done":
@ -129,19 +135,19 @@ class Client:
print("Download completed.") print("Download completed.")
return "OK", contents return "OK", contents
def passwd(self, password): async def passwd(self, password):
self.send_command("PASSWD", password) await self.send_command("PASSWD", password)
return self.read_reply() == ["OK"] return (await self.read_reply()) == ["OK"]
def get_variants(self): async def get_variants(self):
self.send_command("GET_VARIANTS") await self.send_command("GET_VARIANTS")
reply = self.read_reply()[0] reply = (await self.read_reply())[0]
if reply != "OK": if reply != "OK":
raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply)) raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply))
return self.read_json() return await self.read_json()
def get_single_variant(self, error_msg): async def get_single_variant(self, error_msg):
variants = self.get_variants() variants = await self.get_variants()
if len(variants) != 1: if len(variants) != 1:
print(error_msg) print(error_msg)
table = PrettyTable() table = PrettyTable()
@ -152,17 +158,17 @@ class Client:
sys.exit(1) sys.exit(1)
return variants[0][0] return variants[0][0]
def get_json(self, variant): async def get_json(self, variant):
self.send_command("GET_JSON", variant) await self.send_command("GET_JSON", variant)
reply = self.read_reply() reply = await self.read_reply()
if reply[0] != "OK": if reply[0] != "OK":
return reply[0], None return reply[0], None
length = int(reply[1]) length = int(reply[1])
json_str = self.fsocket.read(length).decode("ascii") json_str = (await self.reader.read(length)).decode("ascii")
return "OK", json_str return "OK", json_str
def main(): async def main_async():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)") parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)")
parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)") parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)")
@ -186,6 +192,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
client = Client(args.server, args.port, args.cert) client = Client(args.server, args.port, args.cert)
await client.connect()
try: try:
if args.action == "build": if args.action == "build":
# do this before user enters password so errors are reported without unnecessary user action # do this before user enters password so errors are reported without unnecessary user action
@ -216,7 +223,7 @@ def main():
password = getpass("Current password: ") password = getpass("Current password: ")
else: else:
password = getpass() password = getpass()
if not client.login(args.username, password): if not await client.login(args.username, password):
print("Login failed") print("Login failed")
sys.exit(1) sys.exit(1)
@ -229,12 +236,12 @@ def main():
print("Passwords do not match") print("Passwords do not match")
password = getpass("New password: ") password = getpass("New password: ")
password_confirm = getpass("New password (again): ") password_confirm = getpass("New password (again): ")
if not client.passwd(password): if not await client.passwd(password):
print("Failed to change password") print("Failed to change password")
sys.exit(1) sys.exit(1)
elif args.action == "build": elif args.action == "build":
# build dir and version variables set up above # build dir and version variables set up above
result, contents = client.build(major_ver, rev, args.variant, args.log, args.experimental) result, contents = await client.build(major_ver, rev, args.variant, args.log, args.experimental)
if result != "OK": if result != "OK":
if result == "UNAUTHORIZED": if result == "UNAUTHORIZED":
print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.") print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
@ -245,7 +252,7 @@ def main():
sys.exit(1) sys.exit(1)
zip_unarchive(contents, args.directory) zip_unarchive(contents, args.directory)
elif args.action == "get_variants": elif args.action == "get_variants":
variants = client.get_variants() variants = await client.get_variants()
table = PrettyTable() table = PrettyTable()
table.field_names = ["Variant", "Expiry date"] table.field_names = ["Variant", "Expiry date"]
for variant in variants: for variant in variants:
@ -255,8 +262,8 @@ def main():
if args.variant: if args.variant:
variant = args.variant variant = args.variant
else: else:
variant = client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify") variant = await client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify")
result, json_str = client.get_json(variant) result, json_str = await client.get_json(variant)
if result != "OK": if result != "OK":
if result == "UNAUTHORIZED": if result == "UNAUTHORIZED":
print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.") print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
@ -272,8 +279,10 @@ def main():
else: else:
raise ValueError raise ValueError
finally: finally:
client.close() await client.close()
def main():
asyncio.run(main_async())
if __name__ == "__main__": if __name__ == "__main__":
main() main()