artiq_flash: implement network transparency.

This commit is contained in:
whitequark 2018-01-19 07:39:55 +00:00
parent 80cbef0031
commit b553804e5a
3 changed files with 198 additions and 144 deletions

View File

@ -18,7 +18,9 @@ from artiq.tools import verbosity_args, init_logger, logger, SSHClient
def get_argparser(): def get_argparser():
parser = argparse.ArgumentParser(description="ARTIQ core device development tool") parser = argparse.ArgumentParser(
description="ARTIQ core device development tool",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
verbosity_args(parser) verbosity_args(parser)
@ -27,19 +29,19 @@ def get_argparser():
help="Target to build, one of: " help="Target to build, one of: "
"kc705_dds kasli sayma_rtm sayma_amc_standalone " "kc705_dds kasli sayma_rtm sayma_amc_standalone "
"sayma_amc_drtio_master sayma_amc_drtio_satellite") "sayma_amc_drtio_master sayma_amc_drtio_satellite")
parser.add_argument("-H", "--host", metavar="HOSTNAME", parser.add_argument("-H", "--host",
type=str, default="lab.m-labs.hk", type=str, default="lab.m-labs.hk",
help="SSH host where the development board is located") help="SSH host where the development board is located")
parser.add_argument('-b', "--board", metavar="BOARD", parser.add_argument('-b', "--board",
type=str, default=None, type=str, default="{boardtype}-1",
help="Board to connect to on the development SSH host") help="Board to connect to on the development SSH host")
parser.add_argument("-d", "--device", metavar="DEVICENAME", parser.add_argument("-d", "--device",
type=str, default="{board}.{hostname}", type=str, default="{board}.{host}",
help="Address or domain corresponding to the development board") help="Address or domain corresponding to the development board")
parser.add_argument("-s", "--serial", metavar="SERIAL", parser.add_argument("-s", "--serial",
type=str, default="/dev/ttyUSB_{board}", type=str, default="/dev/ttyUSB_{board}",
help="TTY device corresponding to the development board") help="TTY device corresponding to the development board")
parser.add_argument("-l", "--lockfile", metavar="LOCKFILE", parser.add_argument("-l", "--lockfile",
type=str, default="/run/boards/{board}", type=str, default="/run/boards/{board}",
help="The lockfile to be acquired for the duration of the actions") help="The lockfile to be acquired for the duration of the actions")
parser.add_argument("-w", "--wait", action="store_true", parser.add_argument("-w", "--wait", action="store_true",
@ -59,47 +61,30 @@ def main():
if args.verbose == args.quiet == 0: if args.verbose == args.quiet == 0:
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
def build_dir(*path, target=args.target):
return os.path.join("/tmp", target, *path)
build_args = [] build_args = []
if args.target == "kc705_dds": if args.target == "kc705_dds":
boardtype, firmware = "kc705", "runtime" boardtype, firmware = "kc705", "runtime"
elif args.target == "sayma_amc_standalone": elif args.target == "sayma_amc_standalone":
boardtype, firmware = "sayma", "runtime" boardtype, firmware = "sayma_amc", "runtime"
build_args += ["--rtm-csr-csv", "/tmp/sayma_rtm/sayma_rtm_csr.csv"] build_args += ["--rtm-csr-csv", build_dir("sayma_rtm_csr.csv", target="sayma_rtm")]
elif args.target == "sayma_amc_drtio_master": elif args.target == "sayma_amc_drtio_master":
boardtype, firmware = "sayma", "runtime" boardtype, firmware = "sayma_amc", "runtime"
elif args.target == "sayma_amc_drtio_satellite": elif args.target == "sayma_amc_drtio_satellite":
boardtype, firmware = "sayma", "satman" boardtype, firmware = "sayma_amc", "satman"
elif args.target == "sayma_rtm": elif args.target == "sayma_rtm":
boardtype, firmware = "sayma_rtm", None boardtype, firmware = "sayma_rtm", None
else: else:
raise NotImplementedError("unknown target {}".format(args.target)) raise NotImplementedError("unknown target {}".format(args.target))
flash_args = ["-t", boardtype] board = args.board.format(boardtype=boardtype)
if boardtype == "sayma": device = args.device.format(board=board, host=args.host)
if args.board is None: lockfile = args.lockfile.format(board=board)
args.board = "sayma-1" serial = args.serial.format(board=board)
if args.board == "sayma-1":
flash_args += ["--preinit-command", "ftdi_location 5:2"]
elif args.board == "sayma-2":
flash_args += ["--preinit-command", "ftdi_location 3:10"]
elif args.board == "sayma-3":
flash_args += ["--preinit-command", "ftdi_location 5:1"]
else:
raise NotImplementedError("unknown --preinit-command for {}".format(boardtype))
client = SSHClient(args.host) client = SSHClient(args.host)
substs = {
"target": args.target,
"hostname": args.host,
"boardtype": boardtype,
"board": args.board if args.board else boardtype + "-1",
"firmware": firmware,
}
substs.update({
"devicename": args.device.format(**substs),
"lockfile": args.lockfile.format(**substs),
"serial": args.serial.format(**substs),
})
flock_acquired = False flock_acquired = False
flock_file = None # GC root flock_file = None # GC root
@ -109,10 +94,13 @@ def main():
if not flock_acquired: if not flock_acquired:
logger.info("Acquiring device lock") logger.info("Acquiring device lock")
flock = client.spawn_command("flock --verbose {block} {lockfile} sleep 86400" flock_args = ["flock"]
.format(block="" if args.wait else "--nonblock", if not args.wait:
**substs), flock_args.append("--nonblock")
get_pty=True) flock_args += ["--verbose", lockfile]
flock_args += ["sleep", "86400"]
flock = client.spawn_command(flock_args, get_pty=True)
flock_file = flock.makefile('r') flock_file = flock.makefile('r')
while not flock_acquired: while not flock_acquired:
line = flock_file.readline() line = flock_file.readline()
@ -125,65 +113,52 @@ def main():
logger.error("Failed to get lock") logger.error("Failed to get lock")
sys.exit(1) sys.exit(1)
def artiq_flash(args, synchronous=True): def flash(*steps):
args = flash_args + args flash_args = ["artiq_flash"]
args = ["'{}'".format(arg) if " " in arg else arg for arg in args] for _ in range(args.verbose):
cmd = client.spawn_command( flash_args.append("-v")
"artiq_flash " + " ".join(args), flash_args += ["-H", args.host, "-t", boardtype]
**substs) flash_args += ["--srcbuild", build_dir()]
if synchronous: flash_args += ["--preinit-command", "source /var/boards/{}".format(board)]
client.drain(cmd) flash_args += steps
else: subprocess.check_call(flash_args)
return cmd
for action in args.actions: for action in args.actions:
if action == "build": if action == "build":
logger.info("Building target") logger.info("Building target")
try: try:
subprocess.check_call(["python3", subprocess.check_call([
"-m", "artiq.gateware.targets." + args.target, "python3", "-m", "artiq.gateware.targets." + args.target,
"--no-compile-gateware", "--no-compile-gateware",
*build_args, *build_args,
"--output-dir", "--output-dir", build_dir()])
"/tmp/{target}".format(**substs)])
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
logger.error("Build failed") logger.error("Build failed")
sys.exit(1) sys.exit(1)
elif action == "clean": elif action == "clean":
logger.info("Cleaning build directory") logger.info("Cleaning build directory")
target_dir = "/tmp/{target}".format(**substs) shutil.rmtree(build_dir, ignore_errors=True)
if os.path.isdir(target_dir):
shutil.rmtree(target_dir)
elif action == "reset": elif action == "reset":
lock()
logger.info("Resetting device") logger.info("Resetting device")
artiq_flash(["reset"]) flash("start")
elif action == "flash" or action == "flash+log": elif action == "flash" or action == "flash+log":
def upload_product(product, ext): lock()
logger.info("Uploading {}".format(product))
client.get_sftp().put("/tmp/{target}/software/{product}/{product}.{ext}"
.format(target=args.target, product=product, ext=ext),
"{tmp}/{product}.{ext}"
.format(tmp=client.tmp, product=product, ext=ext))
upload_product("bootloader", "bin")
upload_product(firmware, "fbi")
logger.info("Flashing firmware") logger.info("Flashing firmware")
artiq_flash(["-d", "{tmp}", "proxy", "bootloader", "firmware", flash("proxy", "bootloader", "firmware")
"start" if action == "flash" else ""])
logger.info("Booting firmware")
if action == "flash+log": if action == "flash+log":
logger.info("Booting firmware") flterm = client.spawn_command(["flterm", serial, "--output-only"])
flterm = client.spawn_command( flash("start")
"flterm {serial} " +
"--kernel {tmp}/{firmware}.bin " +
("--upload-only" if action == "boot" else "--output-only"),
**substs)
artiq_flash(["start"], synchronous=False)
client.drain(flterm) client.drain(flterm)
else:
flash("start")
elif action == "connect": elif action == "connect":
lock() lock()
@ -218,10 +193,10 @@ def main():
while True: while True:
local_stream, peer_addr = listener.accept() local_stream, peer_addr = listener.accept()
logger.info("Accepting %s:%s and opening SSH channel to %s:%s", logger.info("Accepting %s:%s and opening SSH channel to %s:%s",
*peer_addr, args.device, port) *peer_addr, device, port)
try: try:
remote_stream = \ remote_stream = \
transport.open_channel('direct-tcpip', (args.device, port), peer_addr) transport.open_channel('direct-tcpip', (device, port), peer_addr)
except Exception: except Exception:
logger.exception("Cannot open channel on port %s", port) logger.exception("Cannot open channel on port %s", port)
continue continue
@ -238,17 +213,13 @@ def main():
logger.info("Forwarding ports {} to core device and logs from core device" logger.info("Forwarding ports {} to core device and logs from core device"
.format(", ".join(map(str, ports)))) .format(", ".join(map(str, ports))))
client.run_command( client.run_command(["flterm", serial, "--output-only"])
"flterm {serial} --output-only",
**substs)
elif action == "hotswap": elif action == "hotswap":
logger.info("Hotswapping firmware") logger.info("Hotswapping firmware")
try: try:
subprocess.check_call(["python3", subprocess.check_call(["artiq_coreboot", "hotswap",
"-m", "artiq.frontend.artiq_coreboot", "hotswap", build_dir("software", firmware, firmware + ".bin")])
"/tmp/{target}/software/{firmware}/{firmware}.bin"
.format(target=args.target, firmware=firmware)])
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
logger.error("Build failed") logger.error("Build failed")
sys.exit(1) sys.exit(1)

View File

@ -5,12 +5,13 @@ import os
import subprocess import subprocess
import tempfile import tempfile
import shutil import shutil
import re
from functools import partial from functools import partial
from artiq import __artiq_dir__ as artiq_dir from artiq import __artiq_dir__ as artiq_dir
from artiq.tools import verbosity_args, init_logger, logger, SSHClient, LocalClient
from artiq.frontend.bit2bin import bit2bin from artiq.frontend.bit2bin import bit2bin
def get_argparser(): def get_argparser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
@ -35,6 +36,12 @@ Prerequisites:
and replug the device. Ensure you are member of the and replug the device. Ensure you are member of the
plugdev group: 'sudo adduser $USER plugdev' and re-login. plugdev group: 'sudo adduser $USER plugdev' and re-login.
""") """)
verbosity_args(parser)
parser.add_argument("-H", "--host", metavar="HOSTNAME",
type=str, default=None,
help="SSH host where the development board is located")
parser.add_argument("-t", "--target", default="kc705", parser.add_argument("-t", "--target", default="kc705",
help="target board, default: %(default)s, one of: " help="target board, default: %(default)s, one of: "
"kc705 kasli sayma_amc sayma_rtm") "kc705 kasli sayma_amc sayma_rtm")
@ -73,14 +80,34 @@ def proxy_path():
class Programmer: class Programmer:
def __init__(self, target_file, preinit_commands): def __init__(self, client, target_file, preinit_commands):
self.target_file = target_file self.client = client
if target_file:
self.target_file = self._transfer_script(target_file)
else:
self.target_file = None
self.preinit_commands = preinit_commands self.preinit_commands = preinit_commands
self.prog = [] self.prog = []
def _transfer_script(self, script):
if isinstance(self.client, LocalClient):
return script
def rewriter(content):
def repl(match):
return "[find {}]".format(self._transfer_script(match.group(1)))
return re.sub(r"\[find (.+?)\]", repl, content, re.DOTALL)
script = os.path.join(scripts_path(), script)
return self.client.transfer_file(script, rewriter)
def _command(self, cmd):
self.prog.append(cmd.replace("{", "{{").replace("}", "}}"))
def init(self): def init(self):
self.prog.extend(self.preinit_commands) for command in self.preinit_commands:
self.prog.append("init") self._command(command)
self._command("init")
def load(self, bitfile): def load(self, bitfile):
raise NotImplementedError raise NotImplementedError
@ -95,46 +122,49 @@ class Programmer:
raise NotImplementedError raise NotImplementedError
def do(self): def do(self):
self.prog.append("exit") self._command("exit")
cmdline = [
"openocd", cmdline = ["openocd"]
"-s", scripts_path() if isinstance(self.client, LocalClient):
] cmdline += ["-s", scripts_path()]
if self.target_file is not None: if self.target_file is not None:
cmdline += ["-f", self.target_file] cmdline += ["-f", self.target_file]
cmdline += ["-c", "; ".join(self.prog)] cmdline += ["-c", "; ".join(self.prog)]
subprocess.check_call(cmdline)
self.client.run_command(cmdline)
class ProgrammerJtagSpi7(Programmer): class ProgrammerJtagSpi7(Programmer):
def __init__(self, target, preinit_commands): def __init__(self, client, target, preinit_commands):
Programmer.__init__(self, os.path.join("board", target + ".cfg"), Programmer.__init__(self, client, os.path.join("board", target + ".cfg"),
preinit_commands) preinit_commands)
self.init() self.init()
def load(self, bitfile, pld=0): def load(self, bitfile, pld=0):
self.prog.append("pld load {} {{{}}}".format(pld, bitfile)) bitfile = self.client.transfer_file(bitfile)
self._command("pld load {} {{{}}}".format(pld, bitfile))
def proxy(self, proxy_bitfile, pld=0): def proxy(self, proxy_bitfile, pld=0):
self.prog.append("jtagspi_init {} {{{}}}".format(pld, proxy_bitfile)) proxy_bitfile = self.client.transfer_file(proxy_bitfile)
self._command("jtagspi_init {} {{{}}}".format(pld, proxy_bitfile))
def flash_binary(self, flashno, address, filename): def flash_binary(self, flashno, address, filename):
# jtagspi_program supports only one flash # jtagspi_program supports only one flash
assert flashno == 0 assert flashno == 0
self.prog.append("jtagspi_program {{{}}} 0x{:x}".format( filename = self.client.transfer_file(filename)
self._command("jtagspi_program {{{}}} 0x{:x}".format(
filename, address)) filename, address))
def start(self): def start(self):
self.prog.append("xc7_program xc7.tap") self._command("xc7_program xc7.tap")
class ProgrammerSayma(Programmer): class ProgrammerSayma(Programmer):
sector_size = 0x10000 sector_size = 0x10000
def __init__(self, preinit_commands): def __init__(self, client, preinit_commands):
# TODO: support Sayma RTM Programmer.__init__(self, client, None, preinit_commands)
Programmer.__init__(self, None, preinit_commands)
self.proxy_loaded = False
self.prog += [ self.prog += [
"interface ftdi", "interface ftdi",
"ftdi_device_desc \"Quad RS232-HS\"", "ftdi_device_desc \"Quad RS232-HS\"",
@ -161,11 +191,12 @@ class ProgrammerSayma(Programmer):
self.init() self.init()
def load(self, bitfile, pld=1): def load(self, bitfile, pld=1):
self.prog.append("pld load {} {{{}}}".format(pld, bitfile)) bitfile = self.client.transfer_file(bitfile)
self._command("pld load {} {{{}}}".format(pld, bitfile))
def proxy(self, proxy_bitfile, pld=1): def proxy(self, proxy_bitfile, pld=1):
self.load(proxy_bitfile, pld) self.load(proxy_bitfile, pld)
self.prog.append("reset halt") self._command("reset halt")
def flash_binary(self, flashno, address, filename): def flash_binary(self, flashno, address, filename):
sector_first = address // self.sector_size sector_first = address // self.sector_size
@ -173,25 +204,23 @@ class ProgrammerSayma(Programmer):
assert size assert size
sector_last = sector_first + (size - 1) // self.sector_size sector_last = sector_first + (size - 1) // self.sector_size
assert sector_last >= sector_first assert sector_last >= sector_first
self.prog += [ filename = self.client.transfer_file(filename)
"flash probe xcu.spi{}".format(flashno), self._command("flash probe xcu.spi{}".format(flashno))
"flash erase_sector {} {} {}".format(flashno, sector_first, sector_last), self._command("flash erase_sector {} {} {}".format(flashno, sector_first, sector_last))
"flash write_bank {} {{{}}} 0x{:x}".format(flashno, filename, address), self._command("flash write_bank {} {{{}}} 0x{:x}".format(flashno, filename, address))
"flash verify_bank {} {{{}}} 0x{:x}".format(flashno, filename, address), self._command("flash verify_bank {} {{{}}} 0x{:x}".format(flashno, filename, address))
]
def start(self): def start(self):
self.proxy_loaded = False self._command("xcu_program xcu.tap")
self.prog.append("xcu_program xcu.tap")
def main(): def main():
parser = get_argparser() args = get_argparser().parse_args()
opts = parser.parse_args() init_logger(args)
config = { config = {
"kc705": { "kc705": {
"programmer_factory": partial(ProgrammerJtagSpi7, "kc705"), "programmer_factory": partial(ProgrammerJtagSpi7, target="kc705"),
"proxy_bitfile": "bscan_spi_xc7k325t.bit", "proxy_bitfile": "bscan_spi_xc7k325t.bit",
"variants": ["nist_clock", "nist_qc2"], "variants": ["nist_clock", "nist_qc2"],
"gateware": (0, 0x000000), "gateware": (0, 0x000000),
@ -200,7 +229,7 @@ def main():
"firmware": (0, 0xb40000), "firmware": (0, 0xb40000),
}, },
"kasli": { "kasli": {
"programmer_factory": partial(ProgrammerJtagSpi7, "kasli"), "programmer_factory": partial(ProgrammerJtagSpi7, target="kasli"),
"proxy_bitfile": "bscan_spi_xc7a100t.bit", "proxy_bitfile": "bscan_spi_xc7a100t.bit",
"variants": ["opticlock"], "variants": ["opticlock"],
"gateware": (0, 0x000000), "gateware": (0, 0x000000),
@ -222,29 +251,34 @@ def main():
"proxy_bitfile": "bscan_spi_xcku040-sayma.bit", "proxy_bitfile": "bscan_spi_xcku040-sayma.bit",
"gateware": (1, 0x150000), "gateware": (1, 0x150000),
}, },
}[opts.target] }[args.target]
variant = opts.variant variant = args.variant
if "variants" in config: if "variants" in config:
if variant is not None and variant not in config["variants"]: if variant is not None and variant not in config["variants"]:
raise SystemExit("Invalid variant for this board") raise SystemExit("Invalid variant for this board")
if variant is None: if variant is None:
variant = config["variants"][0] variant = config["variants"][0]
bin_dir = opts.dir bin_dir = args.dir
if bin_dir is None: if bin_dir is None:
if variant: if variant:
bin_name = "{}-{}".format(opts.target, variant) bin_name = "{}-{}".format(args.target, variant)
else: else:
bin_name = opts.target bin_name = args.target
bin_dir = os.path.join(artiq_dir, "binaries", bin_name) bin_dir = os.path.join(artiq_dir, "binaries", bin_name)
if opts.srcbuild is None and not os.path.exists(bin_dir) and opts.action != ["start"]: if args.srcbuild is None and not os.path.exists(bin_dir) and args.action != ["start"]:
raise SystemExit("Binaries directory '{}' does not exist" raise SystemExit("Binaries directory '{}' does not exist"
.format(bin_dir)) .format(bin_dir))
programmer = config["programmer_factory"](opts.preinit_command) if args.host is None:
client = LocalClient()
else:
client = SSHClient(args.host)
programmer = config["programmer_factory"](client, preinit_commands=args.preinit_command)
conv = False conv = False
for action in opts.action: for action in args.action:
if action == "proxy": if action == "proxy":
proxy_found = False proxy_found = False
for p in [bin_dir, proxy_path(), os.path.expanduser("~/.migen"), for p in [bin_dir, proxy_path(), os.path.expanduser("~/.migen"),
@ -258,10 +292,10 @@ def main():
raise SystemExit( raise SystemExit(
"proxy gateware bitstream {} not found".format(config["proxy_bitfile"])) "proxy gateware bitstream {} not found".format(config["proxy_bitfile"]))
elif action == "gateware": elif action == "gateware":
if opts.srcbuild is None: if args.srcbuild is None:
path = bin_dir path = bin_dir
else: else:
path = os.path.join(opts.srcbuild, "gateware") path = os.path.join(args.srcbuild, "gateware")
bin = os.path.join(path, "top.bin") bin = os.path.join(path, "top.bin")
if not os.access(bin, os.R_OK): if not os.access(bin, os.R_OK):
bin_handle, bin = tempfile.mkstemp() bin_handle, bin = tempfile.mkstemp()
@ -271,29 +305,29 @@ def main():
conv = True conv = True
programmer.flash_binary(*config["gateware"], bin) programmer.flash_binary(*config["gateware"], bin)
elif action == "bootloader": elif action == "bootloader":
if opts.srcbuild is None: if args.srcbuild is None:
path = bin_dir path = bin_dir
else: else:
path = os.path.join(opts.srcbuild, "software", "bootloader") path = os.path.join(args.srcbuild, "software", "bootloader")
programmer.flash_binary(*config["bootloader"], os.path.join(path, "bootloader.bin")) programmer.flash_binary(*config["bootloader"], os.path.join(path, "bootloader.bin"))
elif action == "storage": elif action == "storage":
programmer.flash_binary(*config["storage"], opts.storage) programmer.flash_binary(*config["storage"], args.storage)
elif action == "firmware": elif action == "firmware":
if variant == "satellite": if variant == "satellite":
firmware_name = "satman" firmware_name = "satman"
else: else:
firmware_name = "runtime" firmware_name = "runtime"
if opts.srcbuild is None: if args.srcbuild is None:
path = bin_dir path = bin_dir
else: else:
path = os.path.join(opts.srcbuild, "software", firmware_name) path = os.path.join(args.srcbuild, "software", firmware_name)
programmer.flash_binary(*config["firmware"], programmer.flash_binary(*config["firmware"],
os.path.join(path, firmware_name + ".fbi")) os.path.join(path, firmware_name + ".fbi"))
elif action == "load": elif action == "load":
if opts.srcbuild is None: if args.srcbuild is None:
path = bin_dir path = bin_dir
else: else:
path = os.path.join(opts.srcbuild, "gateware") path = os.path.join(args.srcbuild, "gateware")
programmer.load(os.path.join(path, "top.bit")) programmer.load(os.path.join(path, "top.bit"))
elif action == "start": elif action == "start":
programmer.start() programmer.start()

View File

@ -3,10 +3,9 @@ import logging
import sys import sys
import asyncio import asyncio
import collections import collections
import os
import atexit import atexit
import string import string
import random import os, random, tempfile, shutil, shlex, subprocess
import numpy as np import numpy as np
@ -256,7 +255,45 @@ def get_user_config_dir():
return dir return dir
class SSHClient: class Client:
def transfer_file(self, filename, rewriter=None):
raise NotImplementedError
def run_command(self, cmd, **kws):
raise NotImplementedError
class LocalClient(Client):
def __init__(self):
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
for _ in range(6)])
self.tmp = os.path.join(tempfile.gettempdir(), "artiq" + tmpname)
self._has_tmp = False
def _prepare_tmp(self):
if not self._has_tmp:
os.mkdir(self.tmp)
atexit.register(lambda: shutil.rmtree(self.tmp, ignore_errors=True))
self._has_tmp = True
def transfer_file(self, filename, rewriter=None):
logger.debug("Transferring {}".format(filename))
if rewriter is None:
return filename
else:
tmp_filename = os.path.join(self.tmp, filename.replace(os.sep, "_"))
with open(filename) as local:
self._prepare_tmp()
with open(tmp_filename, 'w') as tmp:
tmp.write(rewriter(local.read()))
return tmp_filename
def run_command(self, cmd, **kws):
logger.debug("Executing {}".format(cmd))
subprocess.check_call([arg.format(tmp=self.tmp, **kws) for arg in cmd])
class SSHClient(Client):
def __init__(self, host): def __init__(self, host):
self.host = host self.host = host
self.ssh = None self.ssh = None
@ -284,16 +321,28 @@ class SSHClient:
if self.sftp is None: if self.sftp is None:
self.sftp = self.get_ssh().open_sftp() self.sftp = self.get_ssh().open_sftp()
self.sftp.mkdir(self.tmp) self.sftp.mkdir(self.tmp)
atexit.register(lambda: self.run_command("rm -rf {tmp}")) atexit.register(lambda: self.run_command(["rm", "-rf", "{tmp}"]))
return self.sftp return self.sftp
def transfer_file(self, filename, rewriter=None):
remote_filename = "{}/{}".format(self.tmp, filename.replace("/", "_"))
logger.debug("Transferring {}".format(filename))
if rewriter is None:
self.get_sftp().put(filename, remote_filename)
else:
with open(filename) as local:
with self.get_sftp().open(remote_filename, 'w') as remote:
remote.write(rewriter(local.read()))
return remote_filename
def spawn_command(self, cmd, get_pty=False, **kws): def spawn_command(self, cmd, get_pty=False, **kws):
chan = self.get_transport().open_session() chan = self.get_transport().open_session()
chan.set_combine_stderr(True) chan.set_combine_stderr(True)
if get_pty: if get_pty:
chan.get_pty() chan.get_pty()
cmd = " ".join([shlex.quote(arg.format(tmp=self.tmp, **kws)) for arg in cmd])
logger.debug("Executing {}".format(cmd)) logger.debug("Executing {}".format(cmd))
chan.exec_command(cmd.format(tmp=self.tmp, **kws)) chan.exec_command(cmd)
return chan return chan
def drain(self, chan): def drain(self, chan):