diff --git a/artiq/frontend/artiq_flash.py b/artiq/frontend/artiq_flash.py index 4c14d9ef9..e1ced8702 100755 --- a/artiq/frontend/artiq_flash.py +++ b/artiq/frontend/artiq_flash.py @@ -118,7 +118,7 @@ class Programmer: return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL) script = os.path.join(scripts_path(), script) - return self._client.transfer_file(script, rewriter) + return self._client.upload(script, rewriter) def add_flash_bank(self, name, tap, index): add_commands(self._board_script, @@ -133,7 +133,7 @@ class Programmer: return self._loaded[pld] = bitfile - bitfile = self._client.transfer_file(bitfile) + bitfile = self._client.upload(bitfile) add_commands(self._script, "pld load {pld} {filename}", pld=pld, filename=bitfile) @@ -141,11 +141,11 @@ class Programmer: def load_proxy(self): raise NotImplementedError - def flash_binary(self, bankname, address, filename): + def write_binary(self, bankname, address, filename): self.load_proxy() size = os.path.getsize(filename) - filename = self._client.transfer_file(filename) + filename = self._client.upload(filename) add_commands(self._script, "flash probe {bankname}", "flash erase_sector {bankname} {firstsector} {lastsector}", @@ -155,6 +155,15 @@ class Programmer: firstsector=address // self._sector_size, lastsector=(address + size - 1) // self._sector_size) + def read_binary(self, bankname, address, length, filename): + self.load_proxy() + + filename = self._client.prepare_download(filename) + add_commands(self._script, + "flash probe {bankname}", + "flash read_bank {bankname} {filename} {address:#x} {length}", + bankname=bankname, filename=filename, address=address, length=length) + def start(self): raise NotImplementedError @@ -175,6 +184,9 @@ class Programmer: cmdline = [arg.replace("{", "{{").replace("}", "}}") for arg in cmdline] self._client.run_command(cmdline) + self._client.download() + + self._script = [] class ProgrammerXC7(Programmer): @@ -304,13 +316,13 @@ def main(): bit2bin(bit_file, bin_file) atexit.register(lambda: os.unlink(gateware_bin)) - programmer.flash_binary(*config["gateware"], gateware_bin) + programmer.write_binary(*config["gateware"], gateware_bin) elif action == "bootloader": bootloader_bin = artifact_path("software", "bootloader", "bootloader.bin") - programmer.flash_binary(*config["bootloader"], bootloader_bin) + programmer.write_binary(*config["bootloader"], bootloader_bin) elif action == "storage": storage_img = args.storage - programmer.flash_binary(*config["storage"], storage_img) + programmer.write_binary(*config["storage"], storage_img) elif action == "firmware": if variant == "satellite": firmware = "satman" @@ -318,7 +330,7 @@ def main(): firmware = "runtime" firmware_fbi = artifact_path("software", firmware, firmware + ".fbi") - programmer.flash_binary(*config["firmware"], firmware_fbi) + programmer.write_binary(*config["firmware"], firmware_fbi) elif action == "load": if args.target == "sayma_rtm": gateware_bit = artifact_path("top.bit") diff --git a/artiq/remoting.py b/artiq/remoting.py index aa3c27f3f..8400b35e8 100644 --- a/artiq/remoting.py +++ b/artiq/remoting.py @@ -6,6 +6,7 @@ import shutil import shlex import subprocess import hashlib +import random __all__ = ["LocalClient", "SSHClient"] @@ -13,7 +14,13 @@ logger = logging.getLogger(__name__) class Client: - def transfer_file(self, filename, rewriter=None): + def upload(self, filename, rewriter=None): + raise NotImplementedError + + def prepare_download(self, filename): + raise NotImplementedError + + def download(self): raise NotImplementedError def run_command(self, cmd, **kws): @@ -24,8 +31,8 @@ class LocalClient(Client): def __init__(self): self._tmp = os.path.join(tempfile.gettempdir(), "artiq") - def transfer_file(self, filename, rewriter=None): - logger.debug("Transferring {}".format(filename)) + def upload(self, filename, rewriter=None): + logger.debug("Uploading {}".format(filename)) if rewriter is None: return filename else: @@ -37,6 +44,13 @@ class LocalClient(Client): tmp.write(rewritten) return tmp_filename + def prepare_download(self, filename): + logger.debug("Downloading {}".format(filename)) + return filename + + def download(self): + pass + def run_command(self, cmd, **kws): logger.debug("Executing {}".format(cmd)) subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd]) @@ -47,8 +61,10 @@ class SSHClient(Client): self.host = host self.ssh = None self.sftp = None - self._tmp = "/tmp/artiq" + self._tmpr = "/tmp/artiq" + self._tmpl = tempfile.TemporaryDirectory(prefix="artiq") self._cached = [] + self._downloads = {} def get_ssh(self): if self.ssh is None: @@ -68,39 +84,59 @@ class SSHClient(Client): if self.sftp is None: self.sftp = self.get_ssh().open_sftp() try: - self._cached = self.sftp.listdir(self._tmp) + self._cached = self.sftp.listdir(self._tmpr) except OSError: - self.sftp.mkdir(self._tmp) + self.sftp.mkdir(self._tmpr) return self.sftp - def transfer_file(self, filename, rewriter=lambda x: x): - sftp = self.get_sftp() + def upload(self, filename, rewriter=lambda x: x): with open(filename, 'rb') as local: rewritten = rewriter(local.read()) digest = hashlib.sha1(rewritten).hexdigest() - remote_filename = os.path.join(self._tmp, digest) - if digest in self._cached: - logger.debug("Using cached {}".format(filename)) - else: - logger.debug("Transferring {}".format(filename)) - # Avoid a race condition by writing into a temporary file - # and atomically replacing - with sftp.open(remote_filename + ".~", "wb") as remote: - remote.write(rewritten) - try: - sftp.rename(remote_filename + ".~", remote_filename) - except IOError: - # Either it already exists (this is OK) or something else - # happened (this isn't) and we need to re-raise - sftp.stat(remote_filename) + remote_filename = "{}/{}".format(self._tmpr, digest) + + sftp = self.get_sftp() + if digest in self._cached: + logger.debug("Using cached {}".format(filename)) + else: + logger.debug("Uploading {}".format(filename)) + # Avoid a race condition by writing into a temporary file + # and atomically replacing + with sftp.open(remote_filename + ".~", "wb") as remote: + remote.write(rewritten) + try: + sftp.rename(remote_filename + ".~", remote_filename) + except IOError: + # Either it already exists (this is OK) or something else + # happened (this isn't) and we need to re-raise + sftp.stat(remote_filename) + return remote_filename + def prepare_download(self, filename): + tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + for _ in range(6)]) + remote_filename = "{}/{}_{}".format(self._tmpr, tmpname, filename) + + _sftp = self.get_sftp() + logger.debug("Downloading {}".format(filename)) + self._downloads[filename] = remote_filename + + return remote_filename + + def download(self): + sftp = self.get_sftp() + for filename, remote_filename in self._downloads.items(): + sftp.get(remote_filename, filename) + + self._downloads = {} + def spawn_command(self, cmd, get_pty=False, **kws): chan = self.get_transport().open_session() chan.set_combine_stderr(True) if get_pty: chan.get_pty() - cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd]) + cmd = " ".join([shlex.quote(arg.format(tmp=self._tmpr, **kws)) for arg in cmd]) logger.debug("Executing {}".format(cmd)) chan.exec_command(cmd) return chan