2014-12-08 16:11:31 +08:00
|
|
|
import importlib.machinery
|
2015-01-28 21:44:15 +08:00
|
|
|
import logging
|
2015-02-18 23:52:31 +08:00
|
|
|
import sys
|
2015-03-11 23:43:07 +08:00
|
|
|
import asyncio
|
2015-08-10 21:58:11 +08:00
|
|
|
import collections
|
2015-11-11 16:22:12 +08:00
|
|
|
import atexit
|
2015-12-09 19:13:57 +08:00
|
|
|
import string
|
2018-01-19 15:39:55 +08:00
|
|
|
import os, random, tempfile, shutil, shlex, subprocess
|
2014-12-08 16:11:31 +08:00
|
|
|
|
2015-10-12 19:46:14 +08:00
|
|
|
import numpy as np
|
|
|
|
|
2015-07-14 04:08:20 +08:00
|
|
|
from artiq.language.environment import is_experiment
|
2015-04-07 13:04:47 +08:00
|
|
|
from artiq.protocols import pyon
|
2016-07-18 22:50:27 +08:00
|
|
|
from artiq.appdirs import user_config_dir
|
2016-08-04 19:42:13 +08:00
|
|
|
from artiq import __version__ as artiq_version
|
2015-04-07 13:04:47 +08:00
|
|
|
|
|
|
|
|
2016-01-26 09:04:06 +08:00
|
|
|
__all__ = ["parse_arguments", "elide", "short_format", "file_import",
|
2016-07-18 22:50:45 +08:00
|
|
|
"get_experiment", "verbosity_args", "simple_network_args",
|
2017-07-18 12:10:33 +08:00
|
|
|
"multiline_log_config", "init_logger", "bind_address_from_args",
|
2016-07-18 22:50:45 +08:00
|
|
|
"atexit_register_coroutine", "exc_to_warning",
|
|
|
|
"asyncio_wait_or_cancel", "TaskObject", "Condition",
|
2016-07-19 00:16:56 +08:00
|
|
|
"get_windows_drives", "get_user_config_dir"]
|
2015-11-11 16:22:12 +08:00
|
|
|
|
|
|
|
|
2015-08-08 23:36:12 +08:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2015-04-07 13:04:47 +08:00
|
|
|
def parse_arguments(arguments):
    """Parse ``name=value`` strings into a dictionary.

    Each value (the text after the first ``=``) is decoded with
    ``pyon.decode``.

    :param arguments: iterable of strings of the form ``name=value``.
    :returns: dict mapping each name to its decoded value; a later
        occurrence of a name overrides an earlier one.
    :raises ValueError: if an argument contains no ``=``.
    """
    d = {}
    for argument in arguments:
        name, eq, value = argument.partition("=")
        if not eq:
            # Without this check an empty string would be passed on to
            # pyon.decode, hiding which argument was malformed.
            raise ValueError(
                "Argument {!r} is not of the form name=value"
                .format(argument))
        d[name] = pyon.decode(value)
    return d
|
2014-12-08 16:11:31 +08:00
|
|
|
|
2015-06-05 00:37:26 +08:00
|
|
|
|
2015-10-12 18:10:58 +08:00
|
|
|
def elide(s, maxlen):
    """Shorten *s* for single-line display.

    The string is cut at its first newline and to at most *maxlen*
    characters. If anything was removed, ``"..."`` is appended and the
    text is shortened further so that, for ``maxlen >= 3``, the result is
    at most *maxlen* characters long.

    >>> elide("hello world", 8)
    'hello...'
    """
    elided = False
    if len(s) > maxlen:
        s = s[:maxlen]
        elided = True
    idx = s.find("\n")
    if idx >= 0:
        s = s[:idx]
        elided = True
    if elided:
        # Leave room for the ellipsis. Clamp at 0: a negative bound would
        # make the slice below drop characters from the *end* instead.
        maxlen = max(maxlen - 3, 0)
        if len(s) > maxlen:
            s = s[:maxlen]
        s += "..."
    return s
|
|
|
|
|
|
|
|
|
|
|
|
def short_format(v):
    """Return a short, human-readable summary of the value *v*.

    - ``None`` -> ``"None"``
    - numbers and booleans (Python or NumPy scalars) -> ``str(v)``
    - ``str``/``np.str_`` values -> the text elided to 50 characters,
      in double quotes
    - lists, dicts and sets -> type name and length, e.g. ``"list (3)"``
    - NumPy arrays -> ``"ndarray"`` followed by the shape
    - anything else -> the type name
    """
    if v is None:
        return "None"
    t = type(v)
    if np.issubdtype(t, np.number) or np.issubdtype(t, np.bool_):
        return str(v)
    # np.str_ rather than np.unicode_: the latter was only an alias of
    # np.str_ and was removed in NumPy 2.0.
    elif np.issubdtype(t, np.str_):
        return "\"" + elide(v, 50) + "\""
    else:
        r = t.__name__
        if t in (list, dict, set):
            r += " ({})".format(len(v))
        if t is np.ndarray:
            r += " " + str(np.shape(v))
        return r
|
|
|
|
|
|
|
|
|
2015-08-28 15:22:59 +08:00
|
|
|
def file_import(filename, prefix="file_import_"):
    """Import the Python source file *filename* as a module.

    The module is named *prefix* followed by the file's base name up to
    its first ``.`` (e.g. ``/a/b/foo.py`` -> ``file_import_foo``). While
    the file is loaded, its directory is put at the front of ``sys.path``
    so that it can import modules located next to it.

    :returns: the loaded module object.
    """
    # os.path.basename also handles "\\" separators on Windows and files
    # directly under "/", which a manual rfind("/") with "i > 0" did not.
    modname = os.path.basename(filename)
    i = modname.find(".")
    if i > 0:
        modname = modname[:i]
    modname = prefix + modname

    path = os.path.dirname(os.path.realpath(filename))
    sys.path.insert(0, path)
    try:
        # NOTE(review): load_module() is deprecated in favour of
        # exec_module(); kept to preserve the existing loading semantics.
        loader = importlib.machinery.SourceFileLoader(modname, filename)
        module = loader.load_module()
    finally:
        sys.path.remove(path)

    return module
|
2015-01-28 21:44:15 +08:00
|
|
|
|
|
|
|
|
2015-04-07 13:04:47 +08:00
|
|
|
def get_experiment(module, experiment=None):
    """Return an experiment class from *module*.

    If *experiment* is given (and truthy), the attribute of that name is
    returned directly. Otherwise the module must contain exactly one
    public (not ``_``-prefixed) attribute accepted by ``is_experiment``,
    and that attribute is returned.

    :raises ValueError: if the module contains no experiment, or more
        than one.
    """
    if experiment:
        return getattr(module, experiment)

    found = [value for name, value in module.__dict__.items()
             if name[0] != "_" and is_experiment(value)]
    if len(found) == 1:
        return found[0]
    if found:
        raise ValueError("More than one experiment found in module")
    raise ValueError("No experiments in module")
|
|
|
|
|
|
|
|
|
2015-01-28 21:44:15 +08:00
|
|
|
def verbosity_args(parser):
    """Add the ``-v/--verbose`` and ``-q/--quiet`` options to *parser*.

    Both live in a "verbosity" argument group, default to 0, and count
    how many times they are given.
    """
    group = parser.add_argument_group("verbosity")
    options = (
        ("-v", "--verbose", "increase logging level"),
        ("-q", "--quiet", "decrease logging level"),
    )
    for short_flag, long_flag, description in options:
        group.add_argument(short_flag, long_flag, default=0, action="count",
                           help=description)
|
2015-01-28 21:44:15 +08:00
|
|
|
|
|
|
|
|
2015-02-16 05:55:43 +08:00
|
|
|
def simple_network_args(parser, default_port):
    """Add network-server options (bind addresses and port(s)) to *parser*.

    :param parser: an ``argparse.ArgumentParser`` (or compatible object).
    :param default_port: either an ``int``, which adds a single
        ``-p/--port`` option, or an iterable of ``(name, purpose, default)``
        tuples, each of which adds a ``--port-<name>`` option.
    """
    group = parser.add_argument_group("network server")
    group.add_argument(
        "--bind", default=[], action="append",
        help="additional hostname or IP address to bind to; "
             "use '*' to bind to all interfaces (default: %(default)s)")
    group.add_argument(
        "--no-localhost-bind", default=False, action="store_true",
        help="do not implicitly also bind to localhost addresses")
    if isinstance(default_port, int):
        group.add_argument("-p", "--port", default=default_port, type=int,
                           help="TCP port to listen on (default: %(default)d)")
    else:
        for name, purpose, default in default_port:
            h = ("TCP port to listen on for {} connections (default: {})"
                 .format(purpose, default))
            group.add_argument("--port-" + name, default=default, type=int,
                               help=h)
|
|
|
|
|
2015-02-16 05:55:43 +08:00
|
|
|
|
2016-01-27 04:59:37 +08:00
|
|
|
class MultilineFormatter(logging.Formatter):
    """Log formatter that marks multi-line records with their line count.

    Records are rendered as ``LEVEL:name:message``. When the rendered text
    spans N > 1 lines, ``<N>`` is inserted before the first colon, e.g.
    ``INFO<3>:name:...``.
    """
    def __init__(self):
        super().__init__("%(levelname)s:%(name)s:%(message)s")

    def format(self, record):
        text = super().format(record)
        line_count = text.count("\n") + 1
        if line_count == 1:
            return text
        colon = text.index(":")
        return text[:colon] + "<{}>".format(line_count) + text[colon:]
|
|
|
|
|
|
|
|
|
|
|
|
def multiline_log_config(level):
    """Set the root logger to *level* and attach a stream handler that
    uses MultilineFormatter.

    The handler is a default ``logging.StreamHandler`` (which writes to
    stderr). Each call adds another handler.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(MultilineFormatter())

    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(handler)
|
|
|
|
|
|
|
|
|
2015-01-28 21:44:15 +08:00
|
|
|
def init_logger(args):
    """Configure root logging from parsed ``--verbose``/``--quiet`` counts.

    The threshold starts at WARNING. Each ``-q`` raises it by 10 (one
    standard level) and each ``-v`` lowers it by 10.
    """
    offset = (args.quiet - args.verbose) * 10
    multiline_log_config(level=logging.WARNING + offset)
|
2015-03-11 23:43:07 +08:00
|
|
|
|
|
|
|
|
2015-12-27 18:03:13 +08:00
|
|
|
def bind_address_from_args(args):
    """Compute the list of addresses a server should bind to.

    :returns: ``None`` if ``"*"`` was given with ``--bind`` (bind to all
        interfaces); otherwise ``args.bind``, preceded by ``127.0.0.1``
        and ``::1`` unless ``--no-localhost-bind`` was given.
    """
    if "*" in args.bind:
        return None
    if not args.no_localhost_bind:
        return ["127.0.0.1", "::1"] + args.bind
    return args.bind
|
|
|
|
|
|
|
|
|
2015-11-11 16:22:12 +08:00
|
|
|
def atexit_register_coroutine(coroutine, loop=None):
    """Arrange for ``coroutine()`` to be run to completion at interpreter
    exit.

    *coroutine* is a coroutine function; it is called only at exit time.
    The event loop is chosen now: *loop* if given, otherwise the one
    returned by ``asyncio.get_event_loop()``.
    """
    target_loop = asyncio.get_event_loop() if loop is None else loop

    def run_at_exit():
        target_loop.run_until_complete(coroutine())

    atexit.register(run_at_exit)
|
|
|
|
|
|
|
|
|
2015-10-03 19:28:57 +08:00
|
|
|
async def exc_to_warning(coro):
    """Await *coro* and log any exception it raises as a warning instead of
    propagating it.

    Cancellation is not a failure: ``asyncio.CancelledError`` is re-raised,
    so cancelling a task that wraps this coroutine still cancels it.
    ``BaseException`` subclasses that are not ``Exception`` (such as
    ``KeyboardInterrupt`` and ``SystemExit``) also propagate.
    """
    try:
        await coro
    except asyncio.CancelledError:
        # Before Python 3.8 CancelledError derived from Exception, so it
        # has to be re-raised explicitly ahead of the generic handler.
        raise
    except Exception:
        logger.warning("asyncio coroutine terminated with exception",
                       exc_info=True)
|
|
|
|
|
|
|
|
|
2015-10-03 19:28:57 +08:00
|
|
|
async def asyncio_wait_or_cancel(fs, **kwargs):
    """Wait on *fs* like ``asyncio.wait``, then cancel whatever is pending.

    Every element of *fs* is wrapped with ``asyncio.ensure_future``, and
    *kwargs* are forwarded to ``asyncio.wait`` (e.g. ``timeout`` or
    ``return_when``). If the wait itself raises (including being
    cancelled), all futures are cancelled and the exception is re-raised.
    Otherwise each still-pending future is cancelled and awaited in turn.

    :returns: the list of futures, in the same order as *fs*.
    """
    futures = list(map(asyncio.ensure_future, fs))
    try:
        _done, pending = await asyncio.wait(futures, **kwargs)
    except BaseException:
        for fut in futures:
            fut.cancel()
        raise
    for fut in pending:
        fut.cancel()
        await asyncio.wait([fut])
    return futures
|
|
|
|
|
|
|
|
|
2015-06-05 14:52:41 +08:00
|
|
|
class TaskObject:
    """Base class for objects that own one background asyncio task.

    Subclasses implement the coroutine ``_do``. ``start`` schedules it as
    ``self.task``, and ``stop`` cancels it and waits for it to finish.
    """
    def start(self):
        """Schedule ``self._do()`` and keep the task in ``self.task``."""
        self.task = asyncio.ensure_future(self._do())

    async def stop(self):
        """Cancel the task, wait for it to end, then delete ``self.task``.

        The task's CancelledError is swallowed. Any other exception raised
        by the task propagates, and in that case ``self.task`` is kept.
        """
        task = self.task
        task.cancel()
        try:
            await asyncio.wait_for(task, None)
        except asyncio.CancelledError:
            pass
        del self.task

    async def _do(self):
        """Body of the background task; subclasses must override it."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
2015-08-10 21:58:11 +08:00
|
|
|
class Condition:
    """Minimal broadcast condition for asyncio.

    It has no associated lock: ``wait`` blocks until the next ``notify``,
    and ``notify`` wakes every task currently waiting.
    """
    def __init__(self, *, loop=None):
        self._loop = asyncio.get_event_loop() if loop is None else loop
        self._waiters = collections.deque()

    async def wait(self):
        """Wait until notified."""
        waiter = asyncio.Future(loop=self._loop)
        self._waiters.append(waiter)
        try:
            await waiter
        finally:
            # Also runs on cancellation, so no stale futures accumulate.
            self._waiters.remove(waiter)

    def notify(self):
        """Wake every waiter whose future has not been resolved yet."""
        pending = (w for w in self._waiters if not w.done())
        for waiter in pending:
            waiter.set_result(False)
|
2015-10-20 00:35:33 +08:00
|
|
|
|
|
|
|
|
2015-12-09 19:13:57 +08:00
|
|
|
def get_windows_drives():
    """Return the letters of the logical drives present (Windows only).

    Bit *i* of the mask returned by ``kernel32.GetLogicalDrives`` is
    tested against the *i*-th uppercase ASCII letter (bit 0 -> ``A``).
    """
    from ctypes import windll

    mask = windll.kernel32.GetLogicalDrives()
    return [letter
            for position, letter in enumerate(string.ascii_uppercase)
            if (mask >> position) & 1]
|
2016-07-18 22:50:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
def get_user_config_dir():
    """Return the per-user ARTIQ configuration directory for the current
    major ARTIQ version, creating it if it does not exist yet."""
    major_version = artiq_version.partition(".")[0]
    config_dir = user_config_dir("artiq", "m-labs", major_version)
    os.makedirs(config_dir, exist_ok=True)
    return config_dir
|
2017-06-25 15:04:29 +08:00
|
|
|
|
|
|
|
|
2018-01-19 15:39:55 +08:00
|
|
|
class Client:
    """Interface for staging files and running commands on a machine.

    Concrete implementations override both methods.
    """
    def transfer_file(self, filename, rewriter=None):
        """Make *filename* available to commands run through this client
        and return the path to use for it."""
        raise NotImplementedError

    def run_command(self, cmd, **kws):
        """Run the argument list *cmd* through this client."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class LocalClient(Client):
|
|
|
|
def __init__(self):
|
|
|
|
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
|
|
|
for _ in range(6)])
|
|
|
|
self.tmp = os.path.join(tempfile.gettempdir(), "artiq" + tmpname)
|
|
|
|
self._has_tmp = False
|
|
|
|
|
|
|
|
def _prepare_tmp(self):
|
|
|
|
if not self._has_tmp:
|
|
|
|
os.mkdir(self.tmp)
|
|
|
|
atexit.register(lambda: shutil.rmtree(self.tmp, ignore_errors=True))
|
|
|
|
self._has_tmp = True
|
|
|
|
|
|
|
|
def transfer_file(self, filename, rewriter=None):
|
|
|
|
logger.debug("Transferring {}".format(filename))
|
|
|
|
if rewriter is None:
|
|
|
|
return filename
|
|
|
|
else:
|
|
|
|
tmp_filename = os.path.join(self.tmp, filename.replace(os.sep, "_"))
|
|
|
|
with open(filename) as local:
|
|
|
|
self._prepare_tmp()
|
|
|
|
with open(tmp_filename, 'w') as tmp:
|
|
|
|
tmp.write(rewriter(local.read()))
|
|
|
|
return tmp_filename
|
|
|
|
|
|
|
|
def run_command(self, cmd, **kws):
|
|
|
|
logger.debug("Executing {}".format(cmd))
|
|
|
|
subprocess.check_call([arg.format(tmp=self.tmp, **kws) for arg in cmd])
|
|
|
|
|
|
|
|
|
|
|
|
class SSHClient(Client):
|
2017-06-25 15:04:29 +08:00
|
|
|
def __init__(self, host):
|
|
|
|
self.host = host
|
|
|
|
self.ssh = None
|
|
|
|
self.sftp = None
|
|
|
|
|
|
|
|
tmpname = "".join([random.Random().choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
|
|
|
for _ in range(6)])
|
|
|
|
self.tmp = "/tmp/artiq" + tmpname
|
|
|
|
|
|
|
|
def get_ssh(self):
|
|
|
|
if self.ssh is None:
|
2017-06-25 15:17:03 +08:00
|
|
|
import paramiko
|
2017-11-26 23:17:35 +08:00
|
|
|
logging.getLogger("paramiko").setLevel(logging.WARNING)
|
2017-06-25 15:04:29 +08:00
|
|
|
self.ssh = paramiko.SSHClient()
|
|
|
|
self.ssh.load_system_host_keys()
|
|
|
|
self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
|
|
|
self.ssh.connect(self.host)
|
2017-11-26 23:17:35 +08:00
|
|
|
logger.debug("Connecting to {}".format(self.host))
|
2017-06-25 15:04:29 +08:00
|
|
|
return self.ssh
|
|
|
|
|
|
|
|
def get_transport(self):
|
|
|
|
return self.get_ssh().get_transport()
|
|
|
|
|
|
|
|
def get_sftp(self):
|
|
|
|
if self.sftp is None:
|
|
|
|
self.sftp = self.get_ssh().open_sftp()
|
|
|
|
self.sftp.mkdir(self.tmp)
|
2018-01-19 15:39:55 +08:00
|
|
|
atexit.register(lambda: self.run_command(["rm", "-rf", "{tmp}"]))
|
2017-06-25 15:04:29 +08:00
|
|
|
return self.sftp
|
|
|
|
|
2018-01-19 15:39:55 +08:00
|
|
|
def transfer_file(self, filename, rewriter=None):
|
|
|
|
remote_filename = "{}/{}".format(self.tmp, filename.replace("/", "_"))
|
|
|
|
logger.debug("Transferring {}".format(filename))
|
|
|
|
if rewriter is None:
|
|
|
|
self.get_sftp().put(filename, remote_filename)
|
|
|
|
else:
|
|
|
|
with open(filename) as local:
|
|
|
|
with self.get_sftp().open(remote_filename, 'w') as remote:
|
|
|
|
remote.write(rewriter(local.read()))
|
|
|
|
return remote_filename
|
|
|
|
|
2017-06-25 15:04:29 +08:00
|
|
|
def spawn_command(self, cmd, get_pty=False, **kws):
|
2017-11-26 23:17:35 +08:00
|
|
|
chan = self.get_transport().open_session()
|
|
|
|
chan.set_combine_stderr(True)
|
2017-06-25 15:04:29 +08:00
|
|
|
if get_pty:
|
|
|
|
chan.get_pty()
|
2018-01-19 15:39:55 +08:00
|
|
|
cmd = " ".join([shlex.quote(arg.format(tmp=self.tmp, **kws)) for arg in cmd])
|
2017-11-26 23:17:35 +08:00
|
|
|
logger.debug("Executing {}".format(cmd))
|
2018-01-19 15:39:55 +08:00
|
|
|
chan.exec_command(cmd)
|
2017-06-25 15:04:29 +08:00
|
|
|
return chan
|
|
|
|
|
|
|
|
def drain(self, chan):
|
|
|
|
while True:
|
|
|
|
char = chan.recv(1)
|
|
|
|
if char == b"":
|
|
|
|
break
|
|
|
|
sys.stderr.write(char.decode("utf-8", errors='replace'))
|
|
|
|
|
|
|
|
def run_command(self, cmd, **kws):
|
|
|
|
self.drain(self.spawn_command(cmd, **kws))
|