artiq_flash: implement flash read functionality.

This commit is contained in:
whitequark 2018-01-27 16:26:02 +00:00
parent 0aacdb0458
commit 0b9c551962
2 changed files with 80 additions and 32 deletions

View File

@ -118,7 +118,7 @@ class Programmer:
return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL)
script = os.path.join(scripts_path(), script)
return self._client.transfer_file(script, rewriter)
return self._client.upload(script, rewriter)
def add_flash_bank(self, name, tap, index):
add_commands(self._board_script,
@ -133,7 +133,7 @@ class Programmer:
return
self._loaded[pld] = bitfile
bitfile = self._client.transfer_file(bitfile)
bitfile = self._client.upload(bitfile)
add_commands(self._script,
"pld load {pld} {filename}",
pld=pld, filename=bitfile)
@ -141,11 +141,11 @@ class Programmer:
def load_proxy(self):
raise NotImplementedError
def flash_binary(self, bankname, address, filename):
def write_binary(self, bankname, address, filename):
self.load_proxy()
size = os.path.getsize(filename)
filename = self._client.transfer_file(filename)
filename = self._client.upload(filename)
add_commands(self._script,
"flash probe {bankname}",
"flash erase_sector {bankname} {firstsector} {lastsector}",
@ -155,6 +155,15 @@ class Programmer:
firstsector=address // self._sector_size,
lastsector=(address + size - 1) // self._sector_size)
def read_binary(self, bankname, address, length, filename):
self.load_proxy()
filename = self._client.prepare_download(filename)
add_commands(self._script,
"flash probe {bankname}",
"flash read_bank {bankname} {filename} {address:#x} {length}",
bankname=bankname, filename=filename, address=address, length=length)
def start(self):
raise NotImplementedError
@ -175,6 +184,9 @@ class Programmer:
cmdline = [arg.replace("{", "{{").replace("}", "}}") for arg in cmdline]
self._client.run_command(cmdline)
self._client.download()
self._script = []
class ProgrammerXC7(Programmer):
@ -304,13 +316,13 @@ def main():
bit2bin(bit_file, bin_file)
atexit.register(lambda: os.unlink(gateware_bin))
programmer.flash_binary(*config["gateware"], gateware_bin)
programmer.write_binary(*config["gateware"], gateware_bin)
elif action == "bootloader":
bootloader_bin = artifact_path("software", "bootloader", "bootloader.bin")
programmer.flash_binary(*config["bootloader"], bootloader_bin)
programmer.write_binary(*config["bootloader"], bootloader_bin)
elif action == "storage":
storage_img = args.storage
programmer.flash_binary(*config["storage"], storage_img)
programmer.write_binary(*config["storage"], storage_img)
elif action == "firmware":
if variant == "satellite":
firmware = "satman"
@ -318,7 +330,7 @@ def main():
firmware = "runtime"
firmware_fbi = artifact_path("software", firmware, firmware + ".fbi")
programmer.flash_binary(*config["firmware"], firmware_fbi)
programmer.write_binary(*config["firmware"], firmware_fbi)
elif action == "load":
if args.target == "sayma_rtm":
gateware_bit = artifact_path("top.bit")

View File

@ -6,6 +6,7 @@ import shutil
import shlex
import subprocess
import hashlib
import random
__all__ = ["LocalClient", "SSHClient"]
@ -13,7 +14,13 @@ logger = logging.getLogger(__name__)
class Client:
def transfer_file(self, filename, rewriter=None):
def upload(self, filename, rewriter=None):
raise NotImplementedError
def prepare_download(self, filename):
raise NotImplementedError
def download(self):
raise NotImplementedError
def run_command(self, cmd, **kws):
@ -24,8 +31,8 @@ class LocalClient(Client):
def __init__(self):
self._tmp = os.path.join(tempfile.gettempdir(), "artiq")
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
def upload(self, filename, rewriter=None):
logger.debug("Uploading {}".format(filename))
if rewriter is None:
return filename
else:
@ -37,6 +44,13 @@ class LocalClient(Client):
tmp.write(rewritten)
return tmp_filename
def prepare_download(self, filename):
logger.debug("Downloading {}".format(filename))
return filename
def download(self):
pass
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd])
@ -47,8 +61,10 @@ class SSHClient(Client):
self.host = host
self.ssh = None
self.sftp = None
self._tmp = "/tmp/artiq"
self._tmpr = "/tmp/artiq"
self._tmpl = tempfile.TemporaryDirectory(prefix="artiq")
self._cached = []
self._downloads = {}
def get_ssh(self):
if self.ssh is None:
@ -68,21 +84,22 @@ class SSHClient(Client):
if self.sftp is None:
self.sftp = self.get_ssh().open_sftp()
try:
self._cached = self.sftp.listdir(self._tmp)
self._cached = self.sftp.listdir(self._tmpr)
except OSError:
self.sftp.mkdir(self._tmp)
self.sftp.mkdir(self._tmpr)
return self.sftp
def transfer_file(self, filename, rewriter=lambda x: x):
sftp = self.get_sftp()
def upload(self, filename, rewriter=lambda x: x):
with open(filename, 'rb') as local:
rewritten = rewriter(local.read())
digest = hashlib.sha1(rewritten).hexdigest()
remote_filename = os.path.join(self._tmp, digest)
remote_filename = "{}/{}".format(self._tmpr, digest)
sftp = self.get_sftp()
if digest in self._cached:
logger.debug("Using cached {}".format(filename))
else:
logger.debug("Transferring {}".format(filename))
logger.debug("Uploading {}".format(filename))
# Avoid a race condition by writing into a temporary file
# and atomically replacing
with sftp.open(remote_filename + ".~", "wb") as remote:
@ -93,14 +110,33 @@ class SSHClient(Client):
# Either it already exists (this is OK) or something else
# happened (this isn't) and we need to re-raise
sftp.stat(remote_filename)
return remote_filename
def prepare_download(self, filename):
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
remote_filename = "{}/{}_{}".format(self._tmpr, tmpname, filename)
_sftp = self.get_sftp()
logger.debug("Downloading {}".format(filename))
self._downloads[filename] = remote_filename
return remote_filename
def download(self):
sftp = self.get_sftp()
for filename, remote_filename in self._downloads.items():
sftp.get(remote_filename, filename)
self._downloads = {}
def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session()
chan.set_combine_stderr(True)
if get_pty:
chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd])
cmd = " ".join([shlex.quote(arg.format(tmp=self._tmpr, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd)
return chan