artiq_flash: implement flash read functionality.

This commit is contained in:
whitequark 2018-01-27 16:26:02 +00:00
parent 0aacdb0458
commit 0b9c551962
2 changed files with 80 additions and 32 deletions

View File

@ -118,7 +118,7 @@ class Programmer:
return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL) return re.sub(rb"\[find (.+?)\]", repl, content, re.DOTALL)
script = os.path.join(scripts_path(), script) script = os.path.join(scripts_path(), script)
return self._client.transfer_file(script, rewriter) return self._client.upload(script, rewriter)
def add_flash_bank(self, name, tap, index): def add_flash_bank(self, name, tap, index):
add_commands(self._board_script, add_commands(self._board_script,
@ -133,7 +133,7 @@ class Programmer:
return return
self._loaded[pld] = bitfile self._loaded[pld] = bitfile
bitfile = self._client.transfer_file(bitfile) bitfile = self._client.upload(bitfile)
add_commands(self._script, add_commands(self._script,
"pld load {pld} {filename}", "pld load {pld} {filename}",
pld=pld, filename=bitfile) pld=pld, filename=bitfile)
@ -141,11 +141,11 @@ class Programmer:
def load_proxy(self): def load_proxy(self):
raise NotImplementedError raise NotImplementedError
def flash_binary(self, bankname, address, filename): def write_binary(self, bankname, address, filename):
self.load_proxy() self.load_proxy()
size = os.path.getsize(filename) size = os.path.getsize(filename)
filename = self._client.transfer_file(filename) filename = self._client.upload(filename)
add_commands(self._script, add_commands(self._script,
"flash probe {bankname}", "flash probe {bankname}",
"flash erase_sector {bankname} {firstsector} {lastsector}", "flash erase_sector {bankname} {firstsector} {lastsector}",
@ -155,6 +155,15 @@ class Programmer:
firstsector=address // self._sector_size, firstsector=address // self._sector_size,
lastsector=(address + size - 1) // self._sector_size) lastsector=(address + size - 1) // self._sector_size)
def read_binary(self, bankname, address, length, filename):
self.load_proxy()
filename = self._client.prepare_download(filename)
add_commands(self._script,
"flash probe {bankname}",
"flash read_bank {bankname} {filename} {address:#x} {length}",
bankname=bankname, filename=filename, address=address, length=length)
def start(self): def start(self):
raise NotImplementedError raise NotImplementedError
@ -175,6 +184,9 @@ class Programmer:
cmdline = [arg.replace("{", "{{").replace("}", "}}") for arg in cmdline] cmdline = [arg.replace("{", "{{").replace("}", "}}") for arg in cmdline]
self._client.run_command(cmdline) self._client.run_command(cmdline)
self._client.download()
self._script = []
class ProgrammerXC7(Programmer): class ProgrammerXC7(Programmer):
@ -304,13 +316,13 @@ def main():
bit2bin(bit_file, bin_file) bit2bin(bit_file, bin_file)
atexit.register(lambda: os.unlink(gateware_bin)) atexit.register(lambda: os.unlink(gateware_bin))
programmer.flash_binary(*config["gateware"], gateware_bin) programmer.write_binary(*config["gateware"], gateware_bin)
elif action == "bootloader": elif action == "bootloader":
bootloader_bin = artifact_path("software", "bootloader", "bootloader.bin") bootloader_bin = artifact_path("software", "bootloader", "bootloader.bin")
programmer.flash_binary(*config["bootloader"], bootloader_bin) programmer.write_binary(*config["bootloader"], bootloader_bin)
elif action == "storage": elif action == "storage":
storage_img = args.storage storage_img = args.storage
programmer.flash_binary(*config["storage"], storage_img) programmer.write_binary(*config["storage"], storage_img)
elif action == "firmware": elif action == "firmware":
if variant == "satellite": if variant == "satellite":
firmware = "satman" firmware = "satman"
@ -318,7 +330,7 @@ def main():
firmware = "runtime" firmware = "runtime"
firmware_fbi = artifact_path("software", firmware, firmware + ".fbi") firmware_fbi = artifact_path("software", firmware, firmware + ".fbi")
programmer.flash_binary(*config["firmware"], firmware_fbi) programmer.write_binary(*config["firmware"], firmware_fbi)
elif action == "load": elif action == "load":
if args.target == "sayma_rtm": if args.target == "sayma_rtm":
gateware_bit = artifact_path("top.bit") gateware_bit = artifact_path("top.bit")

View File

@ -6,6 +6,7 @@ import shutil
import shlex import shlex
import subprocess import subprocess
import hashlib import hashlib
import random
__all__ = ["LocalClient", "SSHClient"] __all__ = ["LocalClient", "SSHClient"]
@ -13,7 +14,13 @@ logger = logging.getLogger(__name__)
class Client: class Client:
def transfer_file(self, filename, rewriter=None): def upload(self, filename, rewriter=None):
raise NotImplementedError
def prepare_download(self, filename):
raise NotImplementedError
def download(self):
raise NotImplementedError raise NotImplementedError
def run_command(self, cmd, **kws): def run_command(self, cmd, **kws):
@ -24,8 +31,8 @@ class LocalClient(Client):
def __init__(self): def __init__(self):
self._tmp = os.path.join(tempfile.gettempdir(), "artiq") self._tmp = os.path.join(tempfile.gettempdir(), "artiq")
def transfer_file(self, filename, rewriter=None): def upload(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename)) logger.debug("Uploading {}".format(filename))
if rewriter is None: if rewriter is None:
return filename return filename
else: else:
@ -37,6 +44,13 @@ class LocalClient(Client):
tmp.write(rewritten) tmp.write(rewritten)
return tmp_filename return tmp_filename
def prepare_download(self, filename):
logger.debug("Downloading {}".format(filename))
return filename
def download(self):
pass
def run_command(self, cmd, **kws): def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd)) logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd]) subprocess.check_call([arg.format(tmp=self._tmp, **kws) for arg in cmd])
@ -47,8 +61,10 @@ class SSHClient(Client):
self.host = host self.host = host
self.ssh = None self.ssh = None
self.sftp = None self.sftp = None
self._tmp = "/tmp/artiq" self._tmpr = "/tmp/artiq"
self._tmpl = tempfile.TemporaryDirectory(prefix="artiq")
self._cached = [] self._cached = []
self._downloads = {}
def get_ssh(self): def get_ssh(self):
if self.ssh is None: if self.ssh is None:
@ -68,21 +84,22 @@ class SSHClient(Client):
if self.sftp is None: if self.sftp is None:
self.sftp = self.get_ssh().open_sftp() self.sftp = self.get_ssh().open_sftp()
try: try:
self._cached = self.sftp.listdir(self._tmp) self._cached = self.sftp.listdir(self._tmpr)
except OSError: except OSError:
self.sftp.mkdir(self._tmp) self.sftp.mkdir(self._tmpr)
return self.sftp return self.sftp
def transfer_file(self, filename, rewriter=lambda x: x): def upload(self, filename, rewriter=lambda x: x):
sftp = self.get_sftp()
with open(filename, 'rb') as local: with open(filename, 'rb') as local:
rewritten = rewriter(local.read()) rewritten = rewriter(local.read())
digest = hashlib.sha1(rewritten).hexdigest() digest = hashlib.sha1(rewritten).hexdigest()
remote_filename = os.path.join(self._tmp, digest) remote_filename = "{}/{}".format(self._tmpr, digest)
sftp = self.get_sftp()
if digest in self._cached: if digest in self._cached:
logger.debug("Using cached {}".format(filename)) logger.debug("Using cached {}".format(filename))
else: else:
logger.debug("Transferring {}".format(filename)) logger.debug("Uploading {}".format(filename))
# Avoid a race condition by writing into a temporary file # Avoid a race condition by writing into a temporary file
# and atomically replacing # and atomically replacing
with sftp.open(remote_filename + ".~", "wb") as remote: with sftp.open(remote_filename + ".~", "wb") as remote:
@ -93,14 +110,33 @@ class SSHClient(Client):
# Either it already exists (this is OK) or something else # Either it already exists (this is OK) or something else
# happened (this isn't) and we need to re-raise # happened (this isn't) and we need to re-raise
sftp.stat(remote_filename) sftp.stat(remote_filename)
return remote_filename return remote_filename
def prepare_download(self, filename):
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
remote_filename = "{}/{}_{}".format(self._tmpr, tmpname, filename)
_sftp = self.get_sftp()
logger.debug("Downloading {}".format(filename))
self._downloads[filename] = remote_filename
return remote_filename
def download(self):
sftp = self.get_sftp()
for filename, remote_filename in self._downloads.items():
sftp.get(remote_filename, filename)
self._downloads = {}
def spawn_command(self, cmd, get_pty=False, **kws): def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session() chan = self.get_transport().open_session()
chan.set_combine_stderr(True) chan.set_combine_stderr(True)
if get_pty: if get_pty:
chan.get_pty() chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self._tmp, **kws)) for arg in cmd]) cmd = " ".join([shlex.quote(arg.format(tmp=self._tmpr, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd)) logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd) chan.exec_command(cmd)
return chan return chan