forked from M-Labs/artiq
client: integrate asyncio with happy eyeballs support
This commit is contained in:
parent
6f3322ea35
commit
688f3d9225
@ -4,6 +4,7 @@ import sys
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
import asyncio
|
||||||
import ssl
|
import ssl
|
||||||
import io
|
import io
|
||||||
import zipfile
|
import zipfile
|
||||||
@ -41,22 +42,27 @@ def zip_unarchive(data, directory):
|
|||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, server, port, cafile):
|
def __init__(self, server, port, cafile):
|
||||||
|
self.server = server
|
||||||
|
self.port = port
|
||||||
self.ssl_context = ssl.create_default_context(cafile=cafile)
|
self.ssl_context = ssl.create_default_context(cafile=cafile)
|
||||||
self.raw_socket = socket.create_connection((server, port))
|
self.reader = None
|
||||||
self.init_websocket(server)
|
self.writer = None
|
||||||
try:
|
|
||||||
self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server)
|
|
||||||
except:
|
|
||||||
self.raw_socket.close()
|
|
||||||
raise
|
|
||||||
self.fsocket = self.socket.makefile("rwb")
|
|
||||||
|
|
||||||
def init_websocket(self, server):
|
async def connect(self):
|
||||||
self.raw_socket.sendall("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"
|
self.reader, self.writer = await asyncio.open_connection(
|
||||||
.format(server).encode())
|
host=self.server,
|
||||||
|
port=self.port,
|
||||||
|
happy_eyeballs_delay=0.25
|
||||||
|
)
|
||||||
|
await self.init_websocket()
|
||||||
|
await self.writer.start_tls(self.ssl_context)
|
||||||
|
|
||||||
|
async def init_websocket(self):
|
||||||
|
self.writer.write("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"
|
||||||
|
.format(self.server).encode())
|
||||||
crlf_count = 0
|
crlf_count = 0
|
||||||
while crlf_count < 4:
|
while crlf_count < 4:
|
||||||
char = self.raw_socket.recv(1)
|
char = await self.reader.read(1)
|
||||||
if not char:
|
if not char:
|
||||||
return ValueError("Connection closed during WebSocket initialization")
|
return ValueError("Connection closed during WebSocket initialization")
|
||||||
if char == b"\r" or char == b"\n":
|
if char == b"\r" or char == b"\n":
|
||||||
@ -64,30 +70,30 @@ class Client:
|
|||||||
else:
|
else:
|
||||||
crlf_count = 0
|
crlf_count = 0
|
||||||
|
|
||||||
def close(self):
|
async def close(self):
|
||||||
self.socket.close()
|
if self.writer:
|
||||||
self.raw_socket.close()
|
await self.writer.close()
|
||||||
|
await self.writer.wait_closed()
|
||||||
|
|
||||||
def send_command(self, *command):
|
async def send_command(self, *command):
|
||||||
self.fsocket.write((" ".join(command) + "\n").encode())
|
await self.writer.write((" ".join(command) + "\n").encode())
|
||||||
self.fsocket.flush()
|
|
||||||
|
|
||||||
def read_line(self):
|
async def read_line(self):
|
||||||
return self.fsocket.readline().decode("ascii")
|
return (await self.reader.readline()).decode("ascii")
|
||||||
|
|
||||||
def read_reply(self):
|
async def read_reply(self):
|
||||||
return self.fsocket.readline().decode("ascii").split()
|
return (await self.reader.readline()).decode("ascii").split()
|
||||||
|
|
||||||
def read_json(self):
|
async def read_json(self):
|
||||||
return json.loads(self.fsocket.readline().decode("ascii"))
|
return json.loads(await self.reader.readline().decode("ascii"))
|
||||||
|
|
||||||
def login(self, username, password):
|
async def login(self, username, password):
|
||||||
self.send_command("LOGIN", username, password)
|
await self.send_command("LOGIN", username, password)
|
||||||
return self.read_reply() == ["HELLO"]
|
return await self.read_reply() == ["HELLO"]
|
||||||
|
|
||||||
def build(self, major_ver, rev, variant, log, experimental_features):
|
async def build(self, major_ver, rev, variant, log, experimental_features):
|
||||||
if not variant:
|
if not variant:
|
||||||
variant = self.get_single_variant(error_msg="User can build more than 1 variant - need to specify")
|
variant = await self.get_single_variant(error_msg="User can build more than 1 variant - need to specify")
|
||||||
print("Building variant: {}".format(variant))
|
print("Building variant: {}".format(variant))
|
||||||
build_args = (
|
build_args = (
|
||||||
rev,
|
rev,
|
||||||
@ -96,19 +102,19 @@ class Client:
|
|||||||
major_ver,
|
major_ver,
|
||||||
*experimental_features,
|
*experimental_features,
|
||||||
)
|
)
|
||||||
self.send_command("BUILD", *build_args)
|
await self.send_command("BUILD", *build_args)
|
||||||
reply = self.read_reply()[0]
|
reply = (await self.read_reply())[0]
|
||||||
if reply != "BUILDING":
|
if reply != "BUILDING":
|
||||||
return reply, None
|
return reply, None
|
||||||
print("Build in progress. This may take 10-15 minutes.")
|
print("Build in progress. This may take 10-15 minutes.")
|
||||||
if log:
|
if log:
|
||||||
line = self.read_line()
|
line = await self.read_line()
|
||||||
while line != "" and line.startswith("LOG"):
|
while line != "" and line.startswith("LOG"):
|
||||||
print(line[4:], end="")
|
print(line[4:], end="")
|
||||||
line = self.read_line()
|
line = await self.read_line()
|
||||||
reply, status = line.split()
|
reply, status = line.split()
|
||||||
else:
|
else:
|
||||||
reply, status = self.read_reply()
|
reply, status = await self.read_reply()
|
||||||
if reply != "DONE":
|
if reply != "DONE":
|
||||||
raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply))
|
raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply))
|
||||||
if status != "done":
|
if status != "done":
|
||||||
@ -129,19 +135,19 @@ class Client:
|
|||||||
print("Download completed.")
|
print("Download completed.")
|
||||||
return "OK", contents
|
return "OK", contents
|
||||||
|
|
||||||
def passwd(self, password):
|
async def passwd(self, password):
|
||||||
self.send_command("PASSWD", password)
|
await self.send_command("PASSWD", password)
|
||||||
return self.read_reply() == ["OK"]
|
return (await self.read_reply()) == ["OK"]
|
||||||
|
|
||||||
def get_variants(self):
|
async def get_variants(self):
|
||||||
self.send_command("GET_VARIANTS")
|
await self.send_command("GET_VARIANTS")
|
||||||
reply = self.read_reply()[0]
|
reply = (await self.read_reply())[0]
|
||||||
if reply != "OK":
|
if reply != "OK":
|
||||||
raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply))
|
raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply))
|
||||||
return self.read_json()
|
return await self.read_json()
|
||||||
|
|
||||||
def get_single_variant(self, error_msg):
|
async def get_single_variant(self, error_msg):
|
||||||
variants = self.get_variants()
|
variants = await self.get_variants()
|
||||||
if len(variants) != 1:
|
if len(variants) != 1:
|
||||||
print(error_msg)
|
print(error_msg)
|
||||||
table = PrettyTable()
|
table = PrettyTable()
|
||||||
@ -152,17 +158,17 @@ class Client:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return variants[0][0]
|
return variants[0][0]
|
||||||
|
|
||||||
def get_json(self, variant):
|
async def get_json(self, variant):
|
||||||
self.send_command("GET_JSON", variant)
|
await self.send_command("GET_JSON", variant)
|
||||||
reply = self.read_reply()
|
reply = await self.read_reply()
|
||||||
if reply[0] != "OK":
|
if reply[0] != "OK":
|
||||||
return reply[0], None
|
return reply[0], None
|
||||||
length = int(reply[1])
|
length = int(reply[1])
|
||||||
json_str = self.fsocket.read(length).decode("ascii")
|
json_str = (await self.reader.read(length)).decode("ascii")
|
||||||
return "OK", json_str
|
return "OK", json_str
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def main_async():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)")
|
parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)")
|
||||||
parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)")
|
parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)")
|
||||||
@ -186,6 +192,7 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
client = Client(args.server, args.port, args.cert)
|
client = Client(args.server, args.port, args.cert)
|
||||||
|
await client.connect()
|
||||||
try:
|
try:
|
||||||
if args.action == "build":
|
if args.action == "build":
|
||||||
# do this before user enters password so errors are reported without unnecessary user action
|
# do this before user enters password so errors are reported without unnecessary user action
|
||||||
@ -216,7 +223,7 @@ def main():
|
|||||||
password = getpass("Current password: ")
|
password = getpass("Current password: ")
|
||||||
else:
|
else:
|
||||||
password = getpass()
|
password = getpass()
|
||||||
if not client.login(args.username, password):
|
if not await client.login(args.username, password):
|
||||||
print("Login failed")
|
print("Login failed")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@ -229,12 +236,12 @@ def main():
|
|||||||
print("Passwords do not match")
|
print("Passwords do not match")
|
||||||
password = getpass("New password: ")
|
password = getpass("New password: ")
|
||||||
password_confirm = getpass("New password (again): ")
|
password_confirm = getpass("New password (again): ")
|
||||||
if not client.passwd(password):
|
if not await client.passwd(password):
|
||||||
print("Failed to change password")
|
print("Failed to change password")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif args.action == "build":
|
elif args.action == "build":
|
||||||
# build dir and version variables set up above
|
# build dir and version variables set up above
|
||||||
result, contents = client.build(major_ver, rev, args.variant, args.log, args.experimental)
|
result, contents = await client.build(major_ver, rev, args.variant, args.log, args.experimental)
|
||||||
if result != "OK":
|
if result != "OK":
|
||||||
if result == "UNAUTHORIZED":
|
if result == "UNAUTHORIZED":
|
||||||
print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
|
print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
|
||||||
@ -245,7 +252,7 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
zip_unarchive(contents, args.directory)
|
zip_unarchive(contents, args.directory)
|
||||||
elif args.action == "get_variants":
|
elif args.action == "get_variants":
|
||||||
variants = client.get_variants()
|
variants = await client.get_variants()
|
||||||
table = PrettyTable()
|
table = PrettyTable()
|
||||||
table.field_names = ["Variant", "Expiry date"]
|
table.field_names = ["Variant", "Expiry date"]
|
||||||
for variant in variants:
|
for variant in variants:
|
||||||
@ -255,8 +262,8 @@ def main():
|
|||||||
if args.variant:
|
if args.variant:
|
||||||
variant = args.variant
|
variant = args.variant
|
||||||
else:
|
else:
|
||||||
variant = client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify")
|
variant = await client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify")
|
||||||
result, json_str = client.get_json(variant)
|
result, json_str = await client.get_json(variant)
|
||||||
if result != "OK":
|
if result != "OK":
|
||||||
if result == "UNAUTHORIZED":
|
if result == "UNAUTHORIZED":
|
||||||
print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
|
print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
|
||||||
@ -272,8 +279,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
finally:
|
finally:
|
||||||
client.close()
|
await client.close()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
asyncio.run(main_async())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user