artiq_flash: cache transferred artifacts.

This commit is contained in:
whitequark 2018-01-19 08:28:04 +00:00
parent 06388e21b7
commit 4fd236d234
4 changed files with 133 additions and 117 deletions

View File

@ -15,7 +15,10 @@ import os
import shutil
import re
from artiq.tools import verbosity_args, init_logger, logger, SSHClient
from artiq.tools import verbosity_args, init_logger
from artiq.remoting import SSHClient
logger = logging.getLogger(__name__)
def get_argparser():
@ -98,7 +101,7 @@ def main():
fuser = client.spawn_command(fuser_args)
fuser_file = fuser.makefile('r')
fuser_match = re.search(r"\((.+?)\)", fuser_file.readline())
if fuser_match.group(1) == os.getenv("USER"):
if fuser_match and fuser_match.group(1) == os.getenv("USER"):
logger.info("Lock already acquired by {}".format(os.getenv("USER")))
flock_acquired = True
return
@ -156,19 +159,22 @@ def main():
logger.info("Resetting device")
flash("start")
elif action == "flash" or action == "flash+log":
elif action == "flash":
lock()
logger.info("Flashing and booting firmware")
flash("proxy", "bootloader", "firmware", "start")
elif action == "flash+log":
lock()
logger.info("Flashing firmware")
flash("proxy", "bootloader", "firmware")
logger.info("Booting firmware")
if action == "flash+log":
flterm = client.spawn_command(["flterm", serial, "--output-only"])
logger.info("Booting firmware")
flash("start")
client.drain(flterm)
else:
flash("start")
elif action == "connect":
lock()

View File

@ -9,9 +9,11 @@ import re
from functools import partial
from artiq import __artiq_dir__ as artiq_dir
from artiq.tools import verbosity_args, init_logger, logger, SSHClient, LocalClient
from artiq.tools import verbosity_args, init_logger
from artiq.remoting import SSHClient, LocalClient
from artiq.frontend.bit2bin import bit2bin
def get_argparser():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
@ -95,8 +97,8 @@ class Programmer:
def rewriter(content):
def repl(match):
return "[find {}]".format(self._transfer_script(match.group(1)))
return re.sub(r"\[find (.+?)\]", repl, content, re.DOTALL)
return self._transfer_script(match.group(1).decode()).encode()
return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL)
script = os.path.join(scripts_path(), script)
return self.client.transfer_file(script, rewriter)
@ -178,9 +180,11 @@ class ProgrammerSayma(Programmer):
"adapter_khz 5000",
"transport select jtag",
"source [find cpld/xilinx-xc7.cfg]", # tap 0, pld 0
# tap 0, pld 0
"source {}".format(self._transfer_script("cpld/xilinx-xc7.cfg")),
# tap 1, pld 1
"set CHIP XCKU040",
"source [find cpld/xilinx-xcu.cfg]", # tap 1, pld 1
"source {}".format(self._transfer_script("cpld/xilinx-xcu.cfg")),
"target create xcu.proxy testee -chain-position xcu.tap",
"set XILINX_USER1 0x02",

108
artiq/remoting.py Normal file
View File

@ -0,0 +1,108 @@
import os
import sys
import logging
import tempfile
import shutil
import shlex
import subprocess
import hashlib
__all__ = ["LocalClient", "SSHClient"]
logger = logging.getLogger(__name__)
class Client:
def transfer_file(self, filename, rewriter=None):
raise NotImplementedError
def run_command(self, cmd, **kws):
raise NotImplementedError
class LocalClient(Client):
def __init__(self):
self._tmp = os.path.join(tempfile.gettempdir(), "artiq")
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
if rewriter is None:
return filename
else:
os.makedirs(self._tmp, exist_ok=True)
with open(filename, 'rb') as local:
rewritten = rewriter(local.read())
tmp_filename = os.path.join(self._tmp, hashlib.sha1(rewritten).hexdigest())
with open(tmp_filename, 'wb') as tmp:
tmp.write(rewritten)
return tmp_filename
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd])
class SSHClient(Client):
def __init__(self, host):
self.host = host
self.ssh = None
self.sftp = None
self._tmp = "/tmp/artiq"
self._cached = []
def get_ssh(self):
if self.ssh is None:
import paramiko
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh = paramiko.SSHClient()
self.ssh.load_system_host_keys()
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.ssh.connect(self.host)
logger.debug("Connecting to {}".format(self.host))
return self.ssh
def get_transport(self):
return self.get_ssh().get_transport()
def get_sftp(self):
if self.sftp is None:
self.sftp = self.get_ssh().open_sftp()
try:
self._cached = self.sftp.listdir(self._tmp)
except OSError:
self.sftp.mkdir(self._tmp)
return self.sftp
def transfer_file(self, filename, rewriter=lambda x: x):
sftp = self.get_sftp()
with open(filename, 'rb') as local:
rewritten = rewriter(local.read())
digest = hashlib.sha1(rewritten).hexdigest()
remote_filename = os.path.join(self._tmp, digest)
if digest in self._cached:
logger.debug("Using cached {}".format(filename))
else:
logger.debug("Transferring {}".format(filename))
with sftp.open(remote_filename, 'wb') as remote:
remote.write(rewritten)
return remote_filename
def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session()
chan.set_combine_stderr(True)
if get_pty:
chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd)
return chan
def drain(self, chan):
while True:
char = chan.recv(1)
if char == b"":
break
sys.stderr.write(char.decode("utf-8", errors='replace'))
def run_command(self, cmd, **kws):
self.drain(self.spawn_command(cmd, **kws))

View File

@ -3,9 +3,8 @@ import logging
import sys
import asyncio
import collections
import atexit
import string
import os, random, tempfile, shutil, shlex, subprocess
import os
import numpy as np
@ -253,104 +252,3 @@ def get_user_config_dir():
dir = user_config_dir("artiq", "m-labs", major)
os.makedirs(dir, exist_ok=True)
return dir
class Client:
def transfer_file(self, filename, rewriter=None):
raise NotImplementedError
def run_command(self, cmd, **kws):
raise NotImplementedError
class LocalClient(Client):
def __init__(self):
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
self.tmp = os.path.join(tempfile.gettempdir(), "artiq" + tmpname)
self._has_tmp = False
def _prepare_tmp(self):
if not self._has_tmp:
os.mkdir(self.tmp)
atexit.register(lambda: shutil.rmtree(self.tmp, ignore_errors=True))
self._has_tmp = True
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
if rewriter is None:
return filename
else:
tmp_filename = os.path.join(self.tmp, filename.replace(os.sep, "_"))
with open(filename) as local:
self._prepare_tmp()
with open(tmp_filename, 'w') as tmp:
tmp.write(rewriter(local.read()))
return tmp_filename
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self.tmp, **kws) for arg in cmd])
class SSHClient(Client):
def __init__(self, host):
self.host = host
self.ssh = None
self.sftp = None
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
self.tmp = "/tmp/artiq" + tmpname
def get_ssh(self):
if self.ssh is None:
import paramiko
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh = paramiko.SSHClient()
self.ssh.load_system_host_keys()
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.ssh.connect(self.host)
logger.debug("Connecting to {}".format(self.host))
return self.ssh
def get_transport(self):
return self.get_ssh().get_transport()
def get_sftp(self):
if self.sftp is None:
self.sftp = self.get_ssh().open_sftp()
self.sftp.mkdir(self.tmp)
atexit.register(lambda: self.run_command(["rm", "-rf", "{tmp}"]))
return self.sftp
def transfer_file(self, filename, rewriter=None):
remote_filename = "{}/{}".format(self.tmp, filename.replace("/", "_"))
logger.debug("Transferring {}".format(filename))
if rewriter is None:
self.get_sftp().put(filename, remote_filename)
else:
with open(filename) as local:
with self.get_sftp().open(remote_filename, 'w') as remote:
remote.write(rewriter(local.read()))
return remote_filename
def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session()
chan.set_combine_stderr(True)
if get_pty:
chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self.tmp, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd)
return chan
def drain(self, chan):
while True:
char = chan.recv(1)
if char == b"":
break
sys.stderr.write(char.decode("utf-8", errors='replace'))
def run_command(self, cmd, **kws):
self.drain(self.spawn_command(cmd, **kws))