artiq/artiq/frontend/afws_client.py

280 lines
11 KiB
Python
Raw Normal View History

2022-02-07 14:28:00 +08:00
#!/usr/bin/env python3
import sys
import argparse
import os
import socket
import ssl
import io
import zipfile
2022-05-18 19:04:13 +08:00
import json
from prettytable import PrettyTable
2022-02-07 14:28:00 +08:00
from getpass import getpass
2022-05-18 19:04:13 +08:00
from tqdm import tqdm
2022-02-07 14:28:00 +08:00
def get_artiq_rev():
try:
import artiq
except ImportError:
return None
2022-09-15 09:15:38 +08:00
rev = artiq._version.get_rev()
if rev == "unknown":
return None
return rev
def get_artiq_major_version():
try:
import artiq
except ImportError:
return None
version = artiq._version.get_version()
return version.split(".")[0]
2022-02-07 14:28:00 +08:00
def zip_unarchive(data, directory):
buf = io.BytesIO(data)
with zipfile.ZipFile(buf) as archive:
archive.extractall(directory)
class Client:
def __init__(self, server, port, cafile):
self.ssl_context = ssl.create_default_context(cafile=cafile)
self.raw_socket = socket.create_connection((server, port))
self.init_websocket(server)
2022-02-07 14:28:00 +08:00
try:
self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server)
except:
self.raw_socket.close()
raise
self.fsocket = self.socket.makefile("rwb")
def init_websocket(self, server):
self.raw_socket.sendall("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"
.format(server).encode())
crlf_count = 0
while crlf_count < 4:
char = self.raw_socket.recv(1)
if not char:
return ValueError("Connection closed during WebSocket initialization")
if char == b"\r" or char == b"\n":
crlf_count += 1
else:
crlf_count = 0
2022-02-07 14:28:00 +08:00
def close(self):
self.socket.close()
self.raw_socket.close()
def send_command(self, *command):
self.fsocket.write((" ".join(command) + "\n").encode())
self.fsocket.flush()
2022-05-18 19:04:13 +08:00
def read_line(self):
return self.fsocket.readline().decode("ascii")
2022-02-07 14:28:00 +08:00
def read_reply(self):
return self.fsocket.readline().decode("ascii").split()
2022-05-18 19:04:13 +08:00
def read_json(self):
return json.loads(self.fsocket.readline().decode("ascii"))
2022-02-07 14:28:00 +08:00
def login(self, username, password):
self.send_command("LOGIN", username, password)
return self.read_reply() == ["HELLO"]
2022-09-19 16:58:41 +08:00
def build(self, major_ver, rev, variant, log, experimental_features):
2022-05-18 19:04:13 +08:00
if not variant:
2022-10-07 11:39:36 +08:00
variant = self.get_single_variant(error_msg="User can build more than 1 variant - need to specify")
2022-05-18 19:04:13 +08:00
print("Building variant: {}".format(variant))
2022-09-15 09:15:38 +08:00
build_args = (
rev,
variant,
"LOG_ENABLE" if log else "LOG_DISABLE",
major_ver,
2022-09-19 16:58:41 +08:00
*experimental_features,
2022-09-15 09:15:38 +08:00
)
self.send_command("BUILD", *build_args)
2022-02-07 14:28:00 +08:00
reply = self.read_reply()[0]
if reply != "BUILDING":
return reply, None
print("Build in progress. This may take 10-15 minutes.")
2022-05-18 19:04:13 +08:00
if log:
line = self.read_line()
while line != "" and line.startswith("LOG"):
print(line[4:], end="")
line = self.read_line()
reply, status = line.split()
else:
reply, status = self.read_reply()
2022-02-07 14:28:00 +08:00
if reply != "DONE":
raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply))
if status != "done":
return status, None
print("Build completed. Downloading...")
reply, length = self.read_reply()
if reply != "PRODUCT":
raise ValueError("Unexpected server reply: expected 'PRODUCT', got '{}'".format(reply))
2022-05-18 19:04:13 +08:00
length = int(length)
contents = bytearray()
with tqdm(total=length, unit="iB", unit_scale=True, unit_divisor=1024) as progress_bar:
total = 0
while total != length:
chunk_len = min(4096, length-total)
contents += self.fsocket.read(chunk_len)
total += chunk_len
progress_bar.update(chunk_len)
2022-02-07 14:28:00 +08:00
print("Download completed.")
return "OK", contents
def passwd(self, password):
self.send_command("PASSWD", password)
return self.read_reply() == ["OK"]
2022-05-18 19:04:13 +08:00
def get_variants(self):
self.send_command("GET_VARIANTS")
reply = self.read_reply()[0]
if reply != "OK":
raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply))
return self.read_json()
2022-02-07 14:28:00 +08:00
2022-10-07 11:39:36 +08:00
def get_single_variant(self, error_msg):
variants = self.get_variants()
if len(variants) != 1:
print(error_msg)
table = PrettyTable()
table.field_names = ["Variant", "Expiry date"]
for variant in variants:
table.add_row(variant)
2022-10-07 11:39:36 +08:00
print(table)
sys.exit(1)
return variants[0][0]
def get_json(self, variant):
self.send_command("GET_JSON", variant)
reply = self.read_reply()
if reply[0] != "OK":
return reply[0], None
length = int(reply[1])
json_str = self.fsocket.read(length).decode("ascii")
return "OK", json_str
2022-02-07 14:28:00 +08:00
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)")
parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)")
parser.add_argument("--cert", default=None, help="SSL certificate file used to authenticate server (default: use system certificates)")
2022-02-07 14:28:00 +08:00
parser.add_argument("username", help="user name for logging into AFWS")
action = parser.add_subparsers(dest="action")
action.required = True
act_build = action.add_parser("build", help="build and download firmware")
2022-09-15 09:15:38 +08:00
act_build.add_argument("--major-ver", default=None, help="ARTIQ major version")
2022-02-07 14:28:00 +08:00
act_build.add_argument("--rev", default=None, help="revision to build (default: currently installed ARTIQ revision)")
2022-05-18 19:04:13 +08:00
act_build.add_argument("--log", action="store_true", help="Display the build log")
2022-09-19 16:58:41 +08:00
act_build.add_argument("--experimental", action="append", default=[], help="enable an experimental feature (can be repeatedly specified to enable multiple features)")
2022-02-07 14:28:00 +08:00
act_build.add_argument("directory", help="output directory")
2022-05-18 19:04:13 +08:00
act_build.add_argument("variant", nargs="?", default=None, help="variant to build (can be omitted if user is authorised to build only one)")
2022-02-07 14:28:00 +08:00
act_passwd = action.add_parser("passwd", help="change password")
2022-05-18 19:04:13 +08:00
act_get_variants = action.add_parser("get_variants", help="get available variants and expiry dates")
2022-10-07 11:39:36 +08:00
act_get_json = action.add_parser("get_json", help="get JSON description file of variant")
act_get_json.add_argument("variant", nargs="?", default=None, help="variant to get (can be omitted if user is authorised to build only one)")
act_get_json.add_argument("-o", "--out", default=None, help="output JSON file")
act_get_json.add_argument("-f", "--force", action="store_true", help="overwrite file if it already exists")
2022-02-07 14:28:00 +08:00
args = parser.parse_args()
client = Client(args.server, args.port, args.cert)
2022-02-07 14:28:00 +08:00
try:
if args.action == "build":
# do this before user enters password so errors are reported without unnecessary user action
2022-02-07 14:28:00 +08:00
try:
os.mkdir(args.directory)
except FileExistsError:
try:
if any(os.scandir(args.directory)):
print("Output directory already exists and is not empty. Please remove it and try again.")
sys.exit(1)
except NotADirectoryError:
print("A file with the same name as the output directory already exists. Please remove it and try again.")
2022-02-07 14:28:00 +08:00
sys.exit(1)
2022-09-15 09:15:38 +08:00
major_ver = args.major_ver
if major_ver is None:
major_ver = get_artiq_major_version()
if major_ver is None:
print("Unable to determine currently installed ARTIQ major version. Specify manually using --major-ver.")
sys.exit(1)
2022-02-07 14:28:00 +08:00
rev = args.rev
if rev is None:
rev = get_artiq_rev()
if rev is None:
print("Unable to determine currently installed ARTIQ revision. Specify manually using --rev.")
sys.exit(1)
if args.action == "passwd":
password = getpass("Current password: ")
else:
password = getpass()
if not client.login(args.username, password):
print("Login failed")
sys.exit(1)
print("Logged in successfully.")
if args.action == "passwd":
print("Password must made of alphanumeric characters (a-z, A-Z, 0-9) and be at least 8 characters long.")
password = getpass("New password: ")
password_confirm = getpass("New password (again): ")
while password != password_confirm:
print("Passwords do not match")
password = getpass("New password: ")
password_confirm = getpass("New password (again): ")
if not client.passwd(password):
print("Failed to change password")
sys.exit(1)
elif args.action == "build":
# build dir and version variables set up above
2022-09-19 16:58:41 +08:00
result, contents = client.build(major_ver, rev, args.variant, args.log, args.experimental)
2022-02-07 14:28:00 +08:00
if result != "OK":
if result == "UNAUTHORIZED":
print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
2022-05-18 19:04:13 +08:00
elif result == "TOOMANY":
print("Too many builds in a queue. Please wait for others to finish.")
2022-02-07 14:28:00 +08:00
else:
print("Build failed: {}".format(result))
sys.exit(1)
zip_unarchive(contents, args.directory)
2022-05-18 19:04:13 +08:00
elif args.action == "get_variants":
variants = client.get_variants()
2022-05-18 19:04:13 +08:00
table = PrettyTable()
table.field_names = ["Variant", "Expiry date"]
for variant in variants:
table.add_row(variant)
2022-05-18 19:04:13 +08:00
print(table)
2022-10-07 11:39:36 +08:00
elif args.action == "get_json":
if args.variant:
variant = args.variant
else:
variant = client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify")
result, json_str = client.get_json(variant)
if result != "OK":
if result == "UNAUTHORIZED":
print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
sys.exit(1)
if args.out:
if not args.force and os.path.exists(args.out):
print(f"File {args.out} already exists. You can use -f to overwrite the existing file.")
sys.exit(1)
with open(args.out, "w") as f:
f.write(json_str)
else:
print(json_str)
2022-02-07 14:28:00 +08:00
else:
raise ValueError
finally:
client.close()
if __name__ == "__main__":
main()