2022-02-07 14:28:00 +08:00
#!/usr/bin/env python3
import sys
import argparse
import os
import socket
import ssl
import io
import zipfile
2022-05-18 19:04:13 +08:00
import json
from prettytable import PrettyTable
2022-02-07 14:28:00 +08:00
from getpass import getpass
2022-05-18 19:04:13 +08:00
from tqdm import tqdm
2022-02-07 14:28:00 +08:00
def get_artiq_rev():
    """Return the revision of the installed ARTIQ package.

    Returns ``None`` when ARTIQ cannot be imported or when it reports its
    revision as ``"unknown"``.
    """
    try:
        import artiq
    except ImportError:
        return None
    revision = artiq._version.get_rev()
    return None if revision == "unknown" else revision
def get_artiq_major_version():
    """Return the major component of the installed ARTIQ version string.

    The major version is everything before the first ``"."``. Returns
    ``None`` when ARTIQ cannot be imported.
    """
    try:
        import artiq
    except ImportError:
        return None
    major, _sep, _rest = artiq._version.get_version().partition(".")
    return major
def zip_unarchive(data, directory):
    """Extract the in-memory ZIP archive *data* into *directory*."""
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        archive.extractall(directory)
class Client:
    """Line-oriented client for the AFWS firmware build service.

    The connection is opened as plain TCP and upgraded with an HTTP
    ``Upgrade: websocket`` request; the resulting stream is then wrapped in
    TLS. Commands are sent as single space-separated ASCII lines, and replies
    are read back line by line.
    """

    def __init__(self, server, port, cafile):
        """Connect to *server*:*port* and set up the TLS session.

        :param cafile: CA file used to verify the server certificate, or
            ``None`` to use the system certificates.
        """
        self.ssl_context = ssl.create_default_context(cafile=cafile)
        self.raw_socket = socket.create_connection((server, port))
        try:
            self.init_websocket(server)
            self.socket = self.ssl_context.wrap_socket(self.raw_socket, server_hostname=server)
        except BaseException:
            # Don't leak the TCP connection if the upgrade or the TLS handshake fails.
            self.raw_socket.close()
            raise
        self.fsocket = self.socket.makefile("rwb")

    def init_websocket(self, server):
        """Send the HTTP upgrade request and consume the response headers.

        :raises ValueError: if the connection closes before the end of the
            response headers.
        """
        self.raw_socket.sendall("GET / HTTP/1.1\r\nHost: {}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"
                                .format(server).encode())
        # Read one byte at a time so nothing past the end of the headers is
        # consumed: the bytes that follow belong to the TLS handshake.
        # Four consecutive CR/LF bytes ("\r\n\r\n") terminate the headers.
        crlf_count = 0
        while crlf_count < 4:
            char = self.raw_socket.recv(1)
            if not char:
                raise ValueError("Connection closed during WebSocket initialization")
            if char in (b"\r", b"\n"):
                crlf_count += 1
            else:
                crlf_count = 0

    def close(self):
        """Close the stream wrapper and both sockets."""
        # The makefile() object must be closed too; otherwise socket.close()
        # defers releasing the underlying file descriptor.
        self.fsocket.close()
        self.socket.close()
        self.raw_socket.close()

    def send_command(self, *command):
        """Send *command* words as one space-separated line and flush."""
        self.fsocket.write((" ".join(command) + "\n").encode())
        self.fsocket.flush()

    def read_line(self):
        """Return the next reply line, including its newline ("" at EOF)."""
        return self.fsocket.readline().decode("ascii")

    def read_reply(self):
        """Return the next reply line split into whitespace-separated words."""
        return self.fsocket.readline().decode("ascii").split()

    def read_json(self):
        """Parse the next reply line as JSON."""
        return json.loads(self.fsocket.readline().decode("ascii"))

    def _read_exact(self, length):
        """Read exactly *length* bytes from the stream.

        :raises ConnectionError: if the connection closes before *length*
            bytes arrive.
        """
        data = self.fsocket.read(length)
        if len(data) != length:
            raise ConnectionError("Connection closed: expected {} bytes, got {}".format(length, len(data)))
        return data

    def login(self, username, password):
        """Authenticate; return True iff the server replies ``HELLO``."""
        self.send_command("LOGIN", username, password)
        return self.read_reply() == ["HELLO"]

    def build(self, major_ver, rev, variant, log, experimental_features):
        """Request a firmware build and download the resulting archive.

        :param variant: variant to build. If falsy, the user's only variant is
            used; the process exits if the user has more than one.
        :param log: stream the server's build log to stdout while waiting.
        :param experimental_features: extra feature names appended to the
            BUILD command.
        :returns: ``("OK", contents)`` with the archive bytes on success, or
            ``(reason, None)``, where *reason* is the server's refusal reply or
            its non-``done`` build status.
        :raises ValueError: on an unexpected server reply.
        :raises ConnectionError: if the connection closes during the build or
            the download.
        """
        if not variant:
            variant = self.get_single_variant(error_msg="User can build more than 1 variant - need to specify")
        print("Building variant: {}".format(variant))
        build_args = (
            rev,
            variant,
            "LOG_ENABLE" if log else "LOG_DISABLE",
            major_ver,
            *experimental_features,
        )
        self.send_command("BUILD", *build_args)
        reply = self.read_reply()[0]
        if reply != "BUILDING":
            return reply, None
        print("Build in progress. This may take 10-15 minutes.")
        if log:
            # Echo "LOG ..." lines until the final status line arrives.
            line = self.read_line()
            while line.startswith("LOG"):
                print(line[4:], end="")
                line = self.read_line()
            if not line:
                raise ConnectionError("Connection closed during build")
            reply, status = line.split()
        else:
            reply, status = self.read_reply()
        if reply != "DONE":
            raise ValueError("Unexpected server reply: expected 'DONE', got '{}'".format(reply))
        if status != "done":
            return status, None
        print("Build completed. Downloading...")
        reply, length = self.read_reply()
        if reply != "PRODUCT":
            raise ValueError("Unexpected server reply: expected 'PRODUCT', got '{}'".format(reply))
        length = int(length)
        contents = bytearray()
        with tqdm(total=length, unit="iB", unit_scale=True, unit_divisor=1024) as progress_bar:
            while len(contents) < length:
                # Count the bytes actually received, so an early EOF fails
                # loudly instead of yielding a truncated archive.
                chunk = self._read_exact(min(4096, length - len(contents)))
                contents += chunk
                progress_bar.update(len(chunk))
        print("Download completed.")
        return "OK", contents

    def passwd(self, password):
        """Change the password; return True iff the server replies ``OK``."""
        self.send_command("PASSWD", password)
        return self.read_reply() == ["OK"]

    def get_variants(self):
        """Return the JSON-decoded list of variants available to the user.

        :raises ValueError: if the server does not reply ``OK``.
        """
        self.send_command("GET_VARIANTS")
        reply = self.read_reply()[0]
        if reply != "OK":
            raise ValueError("Unexpected server reply: expected 'OK', got '{}'".format(reply))
        return self.read_json()

    def get_single_variant(self, error_msg):
        """Return the name of the user's only variant.

        If the user does not have exactly one variant, print *error_msg* and a
        table of the variants, then exit the process with status 1.
        """
        variants = self.get_variants()
        if len(variants) != 1:
            print(error_msg)
            table = PrettyTable()
            table.field_names = ["Variant", "Expiry date"]
            for variant in variants:
                table.add_row(variant)
            print(table)
            sys.exit(1)
        return variants[0][0]

    def get_json(self, variant):
        """Fetch the JSON description of *variant*.

        :returns: ``("OK", json_str)`` on success, otherwise
            ``(reply, None)`` with the server's first reply word.
        :raises ConnectionError: if the payload is cut short.
        """
        self.send_command("GET_JSON", variant)
        reply = self.read_reply()
        if reply[0] != "OK":
            return reply[0], None
        length = int(reply[1])
        json_str = self._read_exact(length).decode("ascii")
        return "OK", json_str
def main():
    """Command-line entry point: parse arguments and run the selected action.

    Actions: ``build`` (build and download firmware), ``passwd`` (change the
    password), ``get_variants`` (list variants and expiry dates) and
    ``get_json`` (fetch a variant's JSON description). User-facing failures
    print a message and exit with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--server", default="afws.m-labs.hk", help="server to connect to (default: %(default)s)")
    parser.add_argument("--port", default=80, type=int, help="port to connect to (default: %(default)d)")
    parser.add_argument("--cert", default=None, help="SSL certificate file used to authenticate server (default: use system certificates)")
    parser.add_argument("username", help="user name for logging into AFWS")
    action = parser.add_subparsers(dest="action")
    action.required = True

    act_build = action.add_parser("build", help="build and download firmware")
    act_build.add_argument("--major-ver", default=None, help="ARTIQ major version")
    act_build.add_argument("--rev", default=None, help="revision to build (default: currently installed ARTIQ revision)")
    act_build.add_argument("--log", action="store_true", help="Display the build log")
    act_build.add_argument("--experimental", action="append", default=[], help="enable an experimental feature (can be repeatedly specified to enable multiple features)")
    act_build.add_argument("directory", help="output directory")
    act_build.add_argument("variant", nargs="?", default=None, help="variant to build (can be omitted if user is authorised to build only one)")

    action.add_parser("passwd", help="change password")
    action.add_parser("get_variants", help="get available variants and expiry dates")

    act_get_json = action.add_parser("get_json", help="get JSON description file of variant")
    act_get_json.add_argument("variant", nargs="?", default=None, help="variant to get (can be omitted if user is authorised to build only one)")
    act_get_json.add_argument("-o", "--out", default=None, help="output JSON file")
    act_get_json.add_argument("-f", "--force", action="store_true", help="overwrite file if it already exists")

    args = parser.parse_args()
    client = Client(args.server, args.port, args.cert)
    try:
        if args.action == "build":
            # do this before user enters password so errors are reported without unnecessary user action
            try:
                os.mkdir(args.directory)
            except FileExistsError:
                try:
                    # Close the scandir iterator deterministically.
                    with os.scandir(args.directory) as entries:
                        if any(entries):
                            print("Output directory already exists and is not empty. Please remove it and try again.")
                            sys.exit(1)
                except NotADirectoryError:
                    print("A file with the same name as the output directory already exists. Please remove it and try again.")
                    sys.exit(1)
            major_ver = args.major_ver
            if major_ver is None:
                major_ver = get_artiq_major_version()
                if major_ver is None:
                    print("Unable to determine currently installed ARTIQ major version. Specify manually using --major-ver.")
                    sys.exit(1)
            rev = args.rev
            if rev is None:
                rev = get_artiq_rev()
                if rev is None:
                    print("Unable to determine currently installed ARTIQ revision. Specify manually using --rev.")
                    sys.exit(1)

        if args.action == "passwd":
            password = getpass("Current password: ")
        else:
            password = getpass()
        if not client.login(args.username, password):
            print("Login failed")
            sys.exit(1)
        print("Logged in successfully.")

        if args.action == "passwd":
            print("Password must be made of alphanumeric characters (a-z, A-Z, 0-9) and be at least 8 characters long.")
            password = getpass("New password: ")
            password_confirm = getpass("New password (again): ")
            while password != password_confirm:
                print("Passwords do not match")
                password = getpass("New password: ")
                password_confirm = getpass("New password (again): ")
            if not client.passwd(password):
                print("Failed to change password")
                sys.exit(1)
        elif args.action == "build":
            # build dir and version variables set up above
            result, contents = client.build(major_ver, rev, args.variant, args.log, args.experimental)
            if result != "OK":
                if result == "UNAUTHORIZED":
                    print("You are not authorized to build this variant. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
                elif result == "TOOMANY":
                    print("Too many builds in a queue. Please wait for others to finish.")
                else:
                    print("Build failed: {}".format(result))
                sys.exit(1)
            zip_unarchive(contents, args.directory)
        elif args.action == "get_variants":
            variants = client.get_variants()
            table = PrettyTable()
            table.field_names = ["Variant", "Expiry date"]
            for variant in variants:
                table.add_row(variant)
            print(table)
        elif args.action == "get_json":
            if args.variant:
                variant = args.variant
            else:
                variant = client.get_single_variant(error_msg="User can get JSON of more than 1 variant - need to specify")
            result, json_str = client.get_json(variant)
            if result != "OK":
                if result == "UNAUTHORIZED":
                    print(f"You are not authorized to get JSON of variant {variant}. Your firmware subscription may have expired. Contact helpdesk\x40m-labs.hk.")
                else:
                    # Previously any other failure exited silently with status 1.
                    print(f"Failed to get JSON: {result}")
                sys.exit(1)
            if args.out:
                if not args.force and os.path.exists(args.out):
                    print(f"File {args.out} already exists. You can use -f to overwrite the existing file.")
                    sys.exit(1)
                with open(args.out, "w") as f:
                    f.write(json_str)
            else:
                print(json_str)
        else:
            # Unreachable: argparse requires one of the subcommands above.
            raise ValueError
    finally:
        client.close()
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()