artiq_flash: cache transferred artifacts.

This commit is contained in:
whitequark 2018-01-19 08:28:04 +00:00
parent 06388e21b7
commit 4fd236d234
4 changed files with 133 additions and 117 deletions

View File

@ -15,7 +15,10 @@ import os
import shutil import shutil
import re import re
from artiq.tools import verbosity_args, init_logger, logger, SSHClient from artiq.tools import verbosity_args, init_logger
from artiq.remoting import SSHClient
logger = logging.getLogger(__name__)
def get_argparser(): def get_argparser():
@ -98,7 +101,7 @@ def main():
fuser = client.spawn_command(fuser_args) fuser = client.spawn_command(fuser_args)
fuser_file = fuser.makefile('r') fuser_file = fuser.makefile('r')
fuser_match = re.search(r"\((.+?)\)", fuser_file.readline()) fuser_match = re.search(r"\((.+?)\)", fuser_file.readline())
if fuser_match.group(1) == os.getenv("USER"): if fuser_match and fuser_match.group(1) == os.getenv("USER"):
logger.info("Lock already acquired by {}".format(os.getenv("USER"))) logger.info("Lock already acquired by {}".format(os.getenv("USER")))
flock_acquired = True flock_acquired = True
return return
@ -156,19 +159,22 @@ def main():
logger.info("Resetting device") logger.info("Resetting device")
flash("start") flash("start")
elif action == "flash" or action == "flash+log": elif action == "flash":
lock()
logger.info("Flashing and booting firmware")
flash("proxy", "bootloader", "firmware", "start")
elif action == "flash+log":
lock() lock()
logger.info("Flashing firmware") logger.info("Flashing firmware")
flash("proxy", "bootloader", "firmware") flash("proxy", "bootloader", "firmware")
logger.info("Booting firmware")
if action == "flash+log":
flterm = client.spawn_command(["flterm", serial, "--output-only"]) flterm = client.spawn_command(["flterm", serial, "--output-only"])
logger.info("Booting firmware")
flash("start") flash("start")
client.drain(flterm) client.drain(flterm)
else:
flash("start")
elif action == "connect": elif action == "connect":
lock() lock()

View File

@ -9,9 +9,11 @@ import re
from functools import partial from functools import partial
from artiq import __artiq_dir__ as artiq_dir from artiq import __artiq_dir__ as artiq_dir
from artiq.tools import verbosity_args, init_logger, logger, SSHClient, LocalClient from artiq.tools import verbosity_args, init_logger
from artiq.remoting import SSHClient, LocalClient
from artiq.frontend.bit2bin import bit2bin from artiq.frontend.bit2bin import bit2bin
def get_argparser(): def get_argparser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
@ -95,8 +97,8 @@ class Programmer:
def rewriter(content): def rewriter(content):
def repl(match): def repl(match):
return "[find {}]".format(self._transfer_script(match.group(1))) return self._transfer_script(match.group(1).decode()).encode()
return re.sub(r"\[find (.+?)\]", repl, content, re.DOTALL) return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL)
script = os.path.join(scripts_path(), script) script = os.path.join(scripts_path(), script)
return self.client.transfer_file(script, rewriter) return self.client.transfer_file(script, rewriter)
@ -178,9 +180,11 @@ class ProgrammerSayma(Programmer):
"adapter_khz 5000", "adapter_khz 5000",
"transport select jtag", "transport select jtag",
"source [find cpld/xilinx-xc7.cfg]", # tap 0, pld 0 # tap 0, pld 0
"source {}".format(self._transfer_script("cpld/xilinx-xc7.cfg")),
# tap 1, pld 1
"set CHIP XCKU040", "set CHIP XCKU040",
"source [find cpld/xilinx-xcu.cfg]", # tap 1, pld 1 "source {}".format(self._transfer_script("cpld/xilinx-xcu.cfg")),
"target create xcu.proxy testee -chain-position xcu.tap", "target create xcu.proxy testee -chain-position xcu.tap",
"set XILINX_USER1 0x02", "set XILINX_USER1 0x02",

108
artiq/remoting.py Normal file
View File

@ -0,0 +1,108 @@
import os
import sys
import logging
import tempfile
import shutil
import shlex
import subprocess
import hashlib
__all__ = ["LocalClient", "SSHClient"]
logger = logging.getLogger(__name__)
class Client:
def transfer_file(self, filename, rewriter=None):
raise NotImplementedError
def run_command(self, cmd, **kws):
raise NotImplementedError
class LocalClient(Client):
def __init__(self):
self._tmp = os.path.join(tempfile.gettempdir(), "artiq")
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
if rewriter is None:
return filename
else:
os.makedirs(self._tmp, exist_ok=True)
with open(filename, 'rb') as local:
rewritten = rewriter(local.read())
tmp_filename = os.path.join(self._tmp, hashlib.sha1(rewritten).hexdigest())
with open(tmp_filename, 'wb') as tmp:
tmp.write(rewritten)
return tmp_filename
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd])
class SSHClient(Client):
def __init__(self, host):
self.host = host
self.ssh = None
self.sftp = None
self._tmp = "/tmp/artiq"
self._cached = []
def get_ssh(self):
if self.ssh is None:
import paramiko
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh = paramiko.SSHClient()
self.ssh.load_system_host_keys()
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.ssh.connect(self.host)
logger.debug("Connecting to {}".format(self.host))
return self.ssh
def get_transport(self):
return self.get_ssh().get_transport()
def get_sftp(self):
if self.sftp is None:
self.sftp = self.get_ssh().open_sftp()
try:
self._cached = self.sftp.listdir(self._tmp)
except OSError:
self.sftp.mkdir(self._tmp)
return self.sftp
def transfer_file(self, filename, rewriter=lambda x: x):
sftp = self.get_sftp()
with open(filename, 'rb') as local:
rewritten = rewriter(local.read())
digest = hashlib.sha1(rewritten).hexdigest()
remote_filename = os.path.join(self._tmp, digest)
if digest in self._cached:
logger.debug("Using cached {}".format(filename))
else:
logger.debug("Transferring {}".format(filename))
with sftp.open(remote_filename, 'wb') as remote:
remote.write(rewritten)
return remote_filename
def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session()
chan.set_combine_stderr(True)
if get_pty:
chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd)
return chan
def drain(self, chan):
while True:
char = chan.recv(1)
if char == b"":
break
sys.stderr.write(char.decode("utf-8", errors='replace'))
def run_command(self, cmd, **kws):
self.drain(self.spawn_command(cmd, **kws))

View File

@ -3,9 +3,8 @@ import logging
import sys import sys
import asyncio import asyncio
import collections import collections
import atexit
import string import string
import os, random, tempfile, shutil, shlex, subprocess import os
import numpy as np import numpy as np
@ -253,104 +252,3 @@ def get_user_config_dir():
dir = user_config_dir("artiq", "m-labs", major) dir = user_config_dir("artiq", "m-labs", major)
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
return dir return dir
class Client:
def transfer_file(self, filename, rewriter=None):
raise NotImplementedError
def run_command(self, cmd, **kws):
raise NotImplementedError
class LocalClient(Client):
def __init__(self):
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
self.tmp = os.path.join(tempfile.gettempdir(), "artiq" + tmpname)
self._has_tmp = False
def _prepare_tmp(self):
if not self._has_tmp:
os.mkdir(self.tmp)
atexit.register(lambda: shutil.rmtree(self.tmp, ignore_errors=True))
self._has_tmp = True
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
if rewriter is None:
return filename
else:
tmp_filename = os.path.join(self.tmp, filename.replace(os.sep, "_"))
with open(filename) as local:
self._prepare_tmp()
with open(tmp_filename, 'w') as tmp:
tmp.write(rewriter(local.read()))
return tmp_filename
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self.tmp, **kws) for arg in cmd])
class SSHClient(Client):
def __init__(self, host):
self.host = host
self.ssh = None
self.sftp = None
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
self.tmp = "/tmp/artiq" + tmpname
def get_ssh(self):
if self.ssh is None:
import paramiko
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh = paramiko.SSHClient()
self.ssh.load_system_host_keys()
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.ssh.connect(self.host)
logger.debug("Connecting to {}".format(self.host))
return self.ssh
def get_transport(self):
return self.get_ssh().get_transport()
def get_sftp(self):
if self.sftp is None:
self.sftp = self.get_ssh().open_sftp()
self.sftp.mkdir(self.tmp)
atexit.register(lambda: self.run_command(["rm", "-rf", "{tmp}"]))
return self.sftp
def transfer_file(self, filename, rewriter=None):
remote_filename = "{}/{}".format(self.tmp, filename.replace("/", "_"))
logger.debug("Transferring {}".format(filename))
if rewriter is None:
self.get_sftp().put(filename, remote_filename)
else:
with open(filename) as local:
with self.get_sftp().open(remote_filename, 'w') as remote:
remote.write(rewriter(local.read()))
return remote_filename
def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session()
chan.set_combine_stderr(True)
if get_pty:
chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self.tmp, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd)
return chan
def drain(self, chan):
while True:
char = chan.recv(1)
if char == b"":
break
sys.stderr.write(char.decode("utf-8", errors='replace'))
def run_command(self, cmd, **kws):
self.drain(self.spawn_command(cmd, **kws))