From 4fd236d234e20b04e90a504ebccfe771a079694c Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 19 Jan 2018 08:28:04 +0000 Subject: [PATCH] artiq_flash: cache transferred artifacts. --- artiq/frontend/artiq_devtool.py | 24 ++++--- artiq/frontend/artiq_flash.py | 14 +++-- artiq/remoting.py | 108 ++++++++++++++++++++++++++++++++ artiq/tools.py | 104 +----------------------------- 4 files changed, 133 insertions(+), 117 deletions(-) create mode 100644 artiq/remoting.py diff --git a/artiq/frontend/artiq_devtool.py b/artiq/frontend/artiq_devtool.py index c4c230ff4..527ef0bce 100755 --- a/artiq/frontend/artiq_devtool.py +++ b/artiq/frontend/artiq_devtool.py @@ -15,7 +15,10 @@ import os import shutil import re -from artiq.tools import verbosity_args, init_logger, logger, SSHClient +from artiq.tools import verbosity_args, init_logger +from artiq.remoting import SSHClient + +logger = logging.getLogger(__name__) def get_argparser(): @@ -98,7 +101,7 @@ def main(): fuser = client.spawn_command(fuser_args) fuser_file = fuser.makefile('r') fuser_match = re.search(r"\((.+?)\)", fuser_file.readline()) - if fuser_match.group(1) == os.getenv("USER"): + if fuser_match and fuser_match.group(1) == os.getenv("USER"): logger.info("Lock already acquired by {}".format(os.getenv("USER"))) flock_acquired = True return @@ -156,19 +159,22 @@ def main(): logger.info("Resetting device") flash("start") - elif action == "flash" or action == "flash+log": + elif action == "flash": + lock() + + logger.info("Flashing and booting firmware") + flash("proxy", "bootloader", "firmware", "start") + + elif action == "flash+log": lock() logger.info("Flashing firmware") flash("proxy", "bootloader", "firmware") + flterm = client.spawn_command(["flterm", serial, "--output-only"]) logger.info("Booting firmware") - if action == "flash+log": - flterm = client.spawn_command(["flterm", serial, "--output-only"]) - flash("start") - client.drain(flterm) - else: - flash("start") + flash("start") + client.drain(flterm) elif action == "connect": lock() diff --git a/artiq/frontend/artiq_flash.py b/artiq/frontend/artiq_flash.py index fbba5c5ee..e5c68b5f7 100755 --- a/artiq/frontend/artiq_flash.py +++ b/artiq/frontend/artiq_flash.py @@ -9,9 +9,11 @@ import re from functools import partial from artiq import __artiq_dir__ as artiq_dir -from artiq.tools import verbosity_args, init_logger, logger, SSHClient, LocalClient +from artiq.tools import verbosity_args, init_logger +from artiq.remoting import SSHClient, LocalClient from artiq.frontend.bit2bin import bit2bin + def get_argparser(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, @@ -95,8 +97,8 @@ class Programmer: def rewriter(content): def repl(match): - return "[find {}]".format(self._transfer_script(match.group(1))) - return re.sub(r"\[find (.+?)\]", repl, content, re.DOTALL) + return self._transfer_script(match.group(1).decode()).encode() + return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL) script = os.path.join(scripts_path(), script) return self.client.transfer_file(script, rewriter) @@ -178,9 +180,11 @@ class ProgrammerSayma(Programmer): "adapter_khz 5000", "transport select jtag", - "source [find cpld/xilinx-xc7.cfg]", # tap 0, pld 0 + # tap 0, pld 0 + "source {}".format(self._transfer_script("cpld/xilinx-xc7.cfg")), + # tap 1, pld 1 "set CHIP XCKU040", - "source [find cpld/xilinx-xcu.cfg]", # tap 1, pld 1 + "source {}".format(self._transfer_script("cpld/xilinx-xcu.cfg")), "target create xcu.proxy testee -chain-position xcu.tap", "set XILINX_USER1 0x02", diff --git a/artiq/remoting.py b/artiq/remoting.py new file mode 100644 index 000000000..0c6a99899 --- /dev/null +++ b/artiq/remoting.py @@ -0,0 +1,108 @@ +import os +import sys +import logging +import tempfile +import shutil +import shlex +import subprocess +import hashlib + +__all__ = ["LocalClient", "SSHClient"] + +logger = logging.getLogger(__name__) + + +class Client: + def transfer_file(self, filename, rewriter=None): + raise NotImplementedError + + def run_command(self, cmd, **kws): + raise NotImplementedError + + +class LocalClient(Client): + def __init__(self): + self._tmp = os.path.join(tempfile.gettempdir(), "artiq") + + def transfer_file(self, filename, rewriter=None): + logger.debug("Transferring {}".format(filename)) + if rewriter is None: + return filename + else: + os.makedirs(self._tmp, exist_ok=True) + with open(filename, 'rb') as local: + rewritten = rewriter(local.read()) + tmp_filename = os.path.join(self._tmp, hashlib.sha1(rewritten).hexdigest()) + with open(tmp_filename, 'wb') as tmp: + tmp.write(rewritten) + return tmp_filename + + def run_command(self, cmd, **kws): + logger.debug("Executing {}".format(cmd)) + subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd]) + + +class SSHClient(Client): + def __init__(self, host): + self.host = host + self.ssh = None + self.sftp = None + self._tmp = "/tmp/artiq" + self._cached = [] + + def get_ssh(self): + if self.ssh is None: + import paramiko + logging.getLogger("paramiko").setLevel(logging.WARNING) + self.ssh = paramiko.SSHClient() + self.ssh.load_system_host_keys() + self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.ssh.connect(self.host) + logger.debug("Connecting to {}".format(self.host)) + return self.ssh + + def get_transport(self): + return self.get_ssh().get_transport() + + def get_sftp(self): + if self.sftp is None: + self.sftp = self.get_ssh().open_sftp() + try: + self._cached = self.sftp.listdir(self._tmp) + except OSError: + self.sftp.mkdir(self._tmp) + return self.sftp + + def transfer_file(self, filename, rewriter=lambda x: x): + sftp = self.get_sftp() + with open(filename, 'rb') as local: + rewritten = rewriter(local.read()) + digest = hashlib.sha1(rewritten).hexdigest() + remote_filename = os.path.join(self._tmp, digest) + if digest in self._cached: + logger.debug("Using cached {}".format(filename)) + else: + logger.debug("Transferring {}".format(filename)) + with sftp.open(remote_filename, 'wb') as remote: + remote.write(rewritten) + return remote_filename + + def spawn_command(self, cmd, get_pty=False, **kws): + chan = self.get_transport().open_session() + chan.set_combine_stderr(True) + if get_pty: + chan.get_pty() + cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd]) + logger.debug("Executing {}".format(cmd)) + chan.exec_command(cmd) + return chan + + def drain(self, chan): + while True: + char = chan.recv(1) + if char == b"": + break + sys.stderr.write(char.decode("utf-8", errors='replace')) + + def run_command(self, cmd, **kws): + self.drain(self.spawn_command(cmd, **kws)) diff --git a/artiq/tools.py b/artiq/tools.py index 9f2cf97c5..f8f4314bf 100644 --- a/artiq/tools.py +++ b/artiq/tools.py @@ -3,9 +3,8 @@ import logging import sys import asyncio import collections -import atexit import string -import os, random, tempfile, shutil, shlex, subprocess +import os import numpy as np @@ -253,104 +252,3 @@ def get_user_config_dir(): dir = user_config_dir("artiq", "m-labs", major) os.makedirs(dir, exist_ok=True) return dir - - -class Client: - def transfer_file(self, filename, rewriter=None): - raise NotImplementedError - - def run_command(self, cmd, **kws): - raise NotImplementedError - - -class LocalClient(Client): - def __init__(self): - tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") - for _ in range(6)]) - self.tmp = os.path.join(tempfile.gettempdir(), "artiq" + tmpname) - self._has_tmp = False - - def _prepare_tmp(self): - if not self._has_tmp: - os.mkdir(self.tmp) - atexit.register(lambda: shutil.rmtree(self.tmp, ignore_errors=True)) - self._has_tmp = True - - def transfer_file(self, filename, rewriter=None): - logger.debug("Transferring {}".format(filename)) - if rewriter is None: - return filename - else: - tmp_filename = os.path.join(self.tmp, filename.replace(os.sep, "_")) - with open(filename) as local: - self._prepare_tmp() - with open(tmp_filename, 'w') as tmp: - tmp.write(rewriter(local.read())) - return tmp_filename - - def run_command(self, cmd, **kws): - logger.debug("Executing {}".format(cmd)) - subprocess.check_call([arg.format(tmp=self.tmp, **kws) for arg in cmd]) - - -class SSHClient(Client): - def __init__(self, host): - self.host = host - self.ssh = None - self.sftp = None - - tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") - for _ in range(6)]) - self.tmp = "/tmp/artiq" + tmpname - - def get_ssh(self): - if self.ssh is None: - import paramiko - logging.getLogger("paramiko").setLevel(logging.WARNING) - self.ssh = paramiko.SSHClient() - self.ssh.load_system_host_keys() - self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.ssh.connect(self.host) - logger.debug("Connecting to {}".format(self.host)) - return self.ssh - - def get_transport(self): - return self.get_ssh().get_transport() - - def get_sftp(self): - if self.sftp is None: - self.sftp = self.get_ssh().open_sftp() - self.sftp.mkdir(self.tmp) - atexit.register(lambda: self.run_command(["rm", "-rf", "{tmp}"])) - return self.sftp - - def transfer_file(self, filename, rewriter=None): - remote_filename = "{}/{}".format(self.tmp, filename.replace("/", "_")) - logger.debug("Transferring {}".format(filename)) - if rewriter is None: - self.get_sftp().put(filename, remote_filename) - else: - with open(filename) as local: - with self.get_sftp().open(remote_filename, 'w') as remote: - remote.write(rewriter(local.read())) - return remote_filename - - def spawn_command(self, cmd, get_pty=False, **kws): - chan = self.get_transport().open_session() - chan.set_combine_stderr(True) - if get_pty: - chan.get_pty() - cmd = " ".join([shlex.quote(arg.format(tmp=self.tmp, **kws)) for arg in cmd]) - logger.debug("Executing {}".format(cmd)) - chan.exec_command(cmd) - return chan - - def drain(self, chan): - while True: - char = chan.recv(1) - if char == b"": - break - sys.stderr.write(char.decode("utf-8", errors='replace')) - - def run_command(self, cmd, **kws): - self.drain(self.spawn_command(cmd, **kws))