forked from M-Labs/artiq
966ed5d013
See test case – previously, the highest-priority pending run would be used to calculate the timeout, rather than the earliest one. This probably managed to go undetected for that long as any unrelated changes to the pipeline (e.g. new submissions, or experiments pausing) would also cause _get_run() to be re-evaluated.
482 lines
17 KiB
Python
482 lines
17 KiB
Python
import asyncio
|
|
import logging
|
|
from enum import Enum
|
|
from time import time
|
|
|
|
from sipyco.sync_struct import Notifier
|
|
from sipyco.asyncio_tools import TaskObject, Condition
|
|
|
|
from artiq.master.worker import Worker, log_worker_exception
|
|
from artiq.tools import asyncio_wait_or_cancel
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RunStatus(Enum):
|
|
pending = 0
|
|
flushing = 1
|
|
preparing = 2
|
|
prepare_done = 3
|
|
running = 4
|
|
run_done = 5
|
|
analyzing = 6
|
|
deleting = 7
|
|
paused = 8
|
|
|
|
|
|
def _mk_worker_method(name):
|
|
async def worker_method(self, *args, **kwargs):
|
|
if self.worker.closed.is_set():
|
|
return True
|
|
m = getattr(self.worker, name)
|
|
try:
|
|
return await m(*args, **kwargs)
|
|
except Exception as e:
|
|
if isinstance(e, asyncio.CancelledError):
|
|
raise
|
|
if self.worker.closed.is_set():
|
|
logger.debug("suppressing worker exception of terminated run",
|
|
exc_info=True)
|
|
# Return completion on termination
|
|
return True
|
|
else:
|
|
raise
|
|
return worker_method
|
|
|
|
|
|
class Run:
|
|
def __init__(self, rid, pipeline_name,
|
|
wd, expid, priority, due_date, flush,
|
|
pool, **kwargs):
|
|
# called through pool
|
|
self.rid = rid
|
|
self.pipeline_name = pipeline_name
|
|
self.wd = wd
|
|
self.expid = expid
|
|
self.priority = priority
|
|
self.due_date = due_date
|
|
self.flush = flush
|
|
|
|
self.worker = Worker(pool.worker_handlers)
|
|
self.termination_requested = False
|
|
|
|
self._status = RunStatus.pending
|
|
|
|
notification = {
|
|
"pipeline": self.pipeline_name,
|
|
"expid": self.expid,
|
|
"priority": self.priority,
|
|
"due_date": self.due_date,
|
|
"flush": self.flush,
|
|
"status": self._status.name
|
|
}
|
|
notification.update(kwargs)
|
|
self._notifier = pool.notifier
|
|
self._notifier[self.rid] = notification
|
|
self._state_changed = pool.state_changed
|
|
|
|
@property
|
|
def status(self):
|
|
return self._status
|
|
|
|
@status.setter
|
|
def status(self, value):
|
|
self._status = value
|
|
if not self.worker.closed.is_set():
|
|
self._notifier[self.rid]["status"] = self._status.name
|
|
self._state_changed.notify()
|
|
|
|
def priority_key(self):
|
|
"""Return a comparable value that defines a run priority order.
|
|
|
|
Applies only to runs the due date of which has already elapsed.
|
|
"""
|
|
return (self.priority, -(self.due_date or 0), -self.rid)
|
|
|
|
async def close(self):
|
|
# called through pool
|
|
await self.worker.close()
|
|
del self._notifier[self.rid]
|
|
|
|
_build = _mk_worker_method("build")
|
|
|
|
async def build(self):
|
|
await self._build(self.rid, self.pipeline_name,
|
|
self.wd, self.expid,
|
|
self.priority)
|
|
|
|
prepare = _mk_worker_method("prepare")
|
|
run = _mk_worker_method("run")
|
|
resume = _mk_worker_method("resume")
|
|
analyze = _mk_worker_method("analyze")
|
|
|
|
|
|
class RunPool:
|
|
def __init__(self, ridc, worker_handlers, notifier, experiment_db):
|
|
self.runs = dict()
|
|
self.state_changed = Condition()
|
|
|
|
self.ridc = ridc
|
|
self.worker_handlers = worker_handlers
|
|
self.notifier = notifier
|
|
self.experiment_db = experiment_db
|
|
|
|
def submit(self, expid, priority, due_date, flush, pipeline_name):
|
|
# mutates expid to insert head repository revision if None.
|
|
# called through scheduler.
|
|
rid = self.ridc.get()
|
|
if "repo_rev" in expid:
|
|
if expid["repo_rev"] is None:
|
|
expid["repo_rev"] = self.experiment_db.cur_rev
|
|
wd, repo_msg = self.experiment_db.repo_backend.request_rev(
|
|
expid["repo_rev"])
|
|
else:
|
|
wd, repo_msg = None, None
|
|
run = Run(rid, pipeline_name, wd, expid, priority, due_date, flush,
|
|
self, repo_msg=repo_msg)
|
|
self.runs[rid] = run
|
|
self.state_changed.notify()
|
|
return rid
|
|
|
|
async def delete(self, rid):
|
|
# called through deleter
|
|
if rid not in self.runs:
|
|
return
|
|
run = self.runs[rid]
|
|
await run.close()
|
|
if "repo_rev" in run.expid:
|
|
self.experiment_db.repo_backend.release_rev(run.expid["repo_rev"])
|
|
del self.runs[rid]
|
|
|
|
|
|
class PrepareStage(TaskObject):
|
|
def __init__(self, pool, delete_cb):
|
|
self.pool = pool
|
|
self.delete_cb = delete_cb
|
|
|
|
def _get_run(self):
|
|
"""If a run should get prepared now, return it. Otherwise, return a
|
|
float giving the time until the next check, or None if no time-based
|
|
check is required.
|
|
|
|
The latter can be the case if there are no due-date runs, or none
|
|
of them are going to become next-in-line before further pool state
|
|
changes (which will also cause a re-evaluation).
|
|
"""
|
|
pending_runs = list(
|
|
filter(lambda r: r.status == RunStatus.pending,
|
|
self.pool.runs.values()))
|
|
|
|
now = time()
|
|
def is_runnable(r):
|
|
return (r.due_date or 0) < now
|
|
|
|
prepared_max = max((r.priority_key() for r in self.pool.runs.values()
|
|
if r.status == RunStatus.prepare_done),
|
|
default=None)
|
|
def takes_precedence(r):
|
|
return prepared_max is None or r.priority_key() > prepared_max
|
|
|
|
candidate = max(filter(is_runnable, pending_runs),
|
|
key=lambda r: r.priority_key(),
|
|
default=None)
|
|
if candidate is not None and takes_precedence(candidate):
|
|
return candidate
|
|
|
|
return min((r.due_date - now for r in pending_runs
|
|
if (not is_runnable(r) and takes_precedence(r))),
|
|
default=None)
|
|
|
|
async def _do(self):
|
|
while True:
|
|
run = self._get_run()
|
|
if run is None:
|
|
await self.pool.state_changed.wait()
|
|
elif isinstance(run, float):
|
|
await asyncio_wait_or_cancel([self.pool.state_changed.wait()],
|
|
timeout=run)
|
|
else:
|
|
if run.flush:
|
|
run.status = RunStatus.flushing
|
|
while not all(r.status in (RunStatus.pending,
|
|
RunStatus.deleting)
|
|
or r.priority < run.priority
|
|
or r is run
|
|
for r in self.pool.runs.values()):
|
|
ev = [self.pool.state_changed.wait(),
|
|
run.worker.closed.wait()]
|
|
await asyncio_wait_or_cancel(
|
|
ev, return_when=asyncio.FIRST_COMPLETED)
|
|
if run.worker.closed.is_set():
|
|
break
|
|
if run.worker.closed.is_set():
|
|
continue
|
|
run.status = RunStatus.preparing
|
|
try:
|
|
await run.build()
|
|
await run.prepare()
|
|
except:
|
|
logger.error("got worker exception in prepare stage, "
|
|
"deleting RID %d", run.rid)
|
|
log_worker_exception()
|
|
self.delete_cb(run.rid)
|
|
else:
|
|
run.status = RunStatus.prepare_done
|
|
|
|
|
|
class RunStage(TaskObject):
|
|
def __init__(self, pool, delete_cb):
|
|
self.pool = pool
|
|
self.delete_cb = delete_cb
|
|
|
|
def _get_run(self):
|
|
prepared_runs = filter(lambda r: r.status == RunStatus.prepare_done,
|
|
self.pool.runs.values())
|
|
try:
|
|
r = max(prepared_runs, key=lambda r: r.priority_key())
|
|
except ValueError:
|
|
# prepared_runs is an empty sequence
|
|
r = None
|
|
return r
|
|
|
|
async def _do(self):
|
|
stack = []
|
|
|
|
while True:
|
|
next_irun = self._get_run()
|
|
if not stack or (
|
|
next_irun is not None and
|
|
next_irun.priority_key() > stack[-1].priority_key()):
|
|
while next_irun is None:
|
|
await self.pool.state_changed.wait()
|
|
next_irun = self._get_run()
|
|
stack.append(next_irun)
|
|
|
|
run = stack.pop()
|
|
try:
|
|
if run.status == RunStatus.paused:
|
|
run.status = RunStatus.running
|
|
# clear "termination requested" flag now
|
|
# so that if it is set again during the resume, this
|
|
# results in another exception.
|
|
request_termination = run.termination_requested
|
|
run.termination_requested = False
|
|
completed = await run.resume(request_termination)
|
|
else:
|
|
run.status = RunStatus.running
|
|
completed = await run.run()
|
|
except:
|
|
logger.error("got worker exception in run stage, "
|
|
"deleting RID %d", run.rid)
|
|
log_worker_exception()
|
|
self.delete_cb(run.rid)
|
|
else:
|
|
if completed:
|
|
run.status = RunStatus.run_done
|
|
else:
|
|
run.status = RunStatus.paused
|
|
stack.append(run)
|
|
|
|
|
|
class AnalyzeStage(TaskObject):
|
|
def __init__(self, pool, delete_cb):
|
|
self.pool = pool
|
|
self.delete_cb = delete_cb
|
|
|
|
def _get_run(self):
|
|
run_runs = filter(lambda r: r.status == RunStatus.run_done,
|
|
self.pool.runs.values())
|
|
try:
|
|
r = max(run_runs, key=lambda r: r.priority_key())
|
|
except ValueError:
|
|
# run_runs is an empty sequence
|
|
r = None
|
|
return r
|
|
|
|
async def _do(self):
|
|
while True:
|
|
run = self._get_run()
|
|
while run is None:
|
|
await self.pool.state_changed.wait()
|
|
run = self._get_run()
|
|
run.status = RunStatus.analyzing
|
|
try:
|
|
await run.analyze()
|
|
except:
|
|
logger.error("got worker exception in analyze stage of RID %d.",
|
|
run.rid)
|
|
log_worker_exception()
|
|
self.delete_cb(run.rid)
|
|
|
|
|
|
class Pipeline:
|
|
def __init__(self, ridc, deleter, worker_handlers, notifier, experiment_db):
|
|
self.pool = RunPool(ridc, worker_handlers, notifier, experiment_db)
|
|
self._prepare = PrepareStage(self.pool, deleter.delete)
|
|
self._run = RunStage(self.pool, deleter.delete)
|
|
self._analyze = AnalyzeStage(self.pool, deleter.delete)
|
|
|
|
def start(self):
|
|
self._prepare.start()
|
|
self._run.start()
|
|
self._analyze.start()
|
|
|
|
async def stop(self):
|
|
# NB: restart of a stopped pipeline is not supported
|
|
await self._analyze.stop()
|
|
await self._run.stop()
|
|
await self._prepare.stop()
|
|
|
|
|
|
class Deleter(TaskObject):
|
|
"""Provides a synchronous interface for instigating deletion of runs.
|
|
|
|
:meth:`RunPool.delete` is an async function (it needs to close the worker
|
|
connection, etc.), so we maintain a queue of RIDs to delete on a background task.
|
|
"""
|
|
def __init__(self, pipelines):
|
|
self._pipelines = pipelines
|
|
self._queue = asyncio.Queue()
|
|
|
|
def delete(self, rid):
|
|
"""Delete the run with the given RID.
|
|
|
|
Multiple calls for the same RID are silently ignored.
|
|
"""
|
|
logger.debug("delete request for RID %d", rid)
|
|
for pipeline in self._pipelines.values():
|
|
if rid in pipeline.pool.runs:
|
|
pipeline.pool.runs[rid].status = RunStatus.deleting
|
|
break
|
|
self._queue.put_nowait(rid)
|
|
|
|
async def join(self):
|
|
await self._queue.join()
|
|
|
|
async def _delete(self, rid):
|
|
# By looking up the run by RID, we implicitly make sure to delete each run only
|
|
# once.
|
|
for pipeline in self._pipelines.values():
|
|
if rid in pipeline.pool.runs:
|
|
logger.debug("deleting RID %d...", rid)
|
|
await pipeline.pool.delete(rid)
|
|
logger.debug("deletion of RID %d completed", rid)
|
|
break
|
|
|
|
async def _gc_pipelines(self):
|
|
pipeline_names = list(self._pipelines.keys())
|
|
for name in pipeline_names:
|
|
if not self._pipelines[name].pool.runs:
|
|
logger.debug("garbage-collecting pipeline '%s'...", name)
|
|
await self._pipelines[name].stop()
|
|
del self._pipelines[name]
|
|
logger.debug("garbage-collection of pipeline '%s' completed",
|
|
name)
|
|
|
|
async def _do(self):
|
|
while True:
|
|
rid = await self._queue.get()
|
|
await self._delete(rid)
|
|
await self._gc_pipelines()
|
|
self._queue.task_done()
|
|
|
|
|
|
class Scheduler:
|
|
def __init__(self, ridc, worker_handlers, experiment_db):
|
|
self.notifier = Notifier(dict())
|
|
|
|
self._pipelines = dict()
|
|
self._worker_handlers = worker_handlers
|
|
self._experiment_db = experiment_db
|
|
self._terminated = False
|
|
|
|
self._ridc = ridc
|
|
self._deleter = Deleter(self._pipelines)
|
|
|
|
def start(self):
|
|
self._deleter.start()
|
|
|
|
async def stop(self):
|
|
# NB: restart of a stopped scheduler is not supported
|
|
self._terminated = True # prevent further runs from being created
|
|
for pipeline in self._pipelines.values():
|
|
for rid in pipeline.pool.runs.keys():
|
|
self._deleter.delete(rid)
|
|
await self._deleter.join()
|
|
await self._deleter.stop()
|
|
if self._pipelines:
|
|
logger.warning("some pipelines were not garbage-collected")
|
|
|
|
def submit(self, pipeline_name, expid, priority=0, due_date=None, flush=False):
|
|
"""Submits a new run.
|
|
|
|
When called through an experiment, the default values of
|
|
``pipeline_name``, ``expid`` and ``priority`` correspond to those of
|
|
the current run."""
|
|
# mutates expid to insert head repository revision if None
|
|
if self._terminated:
|
|
return
|
|
try:
|
|
pipeline = self._pipelines[pipeline_name]
|
|
except KeyError:
|
|
logger.debug("creating pipeline '%s'", pipeline_name)
|
|
pipeline = Pipeline(self._ridc, self._deleter,
|
|
self._worker_handlers, self.notifier,
|
|
self._experiment_db)
|
|
self._pipelines[pipeline_name] = pipeline
|
|
pipeline.start()
|
|
return pipeline.pool.submit(expid, priority, due_date, flush, pipeline_name)
|
|
|
|
def delete(self, rid):
|
|
"""Kills the run with the specified RID."""
|
|
self._deleter.delete(rid)
|
|
|
|
def request_termination(self, rid):
|
|
"""Requests graceful termination of the run with the specified RID."""
|
|
for pipeline in self._pipelines.values():
|
|
if rid in pipeline.pool.runs:
|
|
run = pipeline.pool.runs[rid]
|
|
if run.status == RunStatus.running or run.status == RunStatus.paused:
|
|
run.termination_requested = True
|
|
else:
|
|
self.delete(rid)
|
|
break
|
|
|
|
def get_status(self):
|
|
"""Returns a dictionary containing information about the runs currently
|
|
tracked by the scheduler.
|
|
|
|
Must not be modified."""
|
|
return self.notifier.raw_view
|
|
|
|
def check_pause(self, rid):
|
|
"""Returns ``True`` if there is a condition that could make ``pause``
|
|
not return immediately (termination requested or higher priority run).
|
|
|
|
The typical purpose of this function is to check from a kernel
|
|
whether returning control to the host and pausing would have an effect,
|
|
in order to avoid the cost of switching kernels in the common case
|
|
where ``pause`` does nothing.
|
|
|
|
This function does not have side effects, and does not have to be
|
|
followed by a call to ``pause``.
|
|
"""
|
|
for pipeline in self._pipelines.values():
|
|
if rid in pipeline.pool.runs:
|
|
run = pipeline.pool.runs[rid]
|
|
if run.status != RunStatus.running:
|
|
return False
|
|
if run.termination_requested:
|
|
return True
|
|
|
|
prepared_runs = filter(lambda r: r.status == RunStatus.prepare_done,
|
|
pipeline.pool.runs.values())
|
|
try:
|
|
r = max(prepared_runs, key=lambda r: r.priority_key())
|
|
except ValueError:
|
|
# prepared_runs is an empty sequence
|
|
return False
|
|
return r.priority_key() > run.priority_key()
|
|
raise KeyError("RID not found")
|