forked from M-Labs/artiq
467 lines
16 KiB
Python
467 lines
16 KiB
Python
import asyncio
|
|
import logging
|
|
from enum import Enum
|
|
from time import time
|
|
|
|
from artiq.master.worker import Worker, log_worker_exception
|
|
from artiq.tools import asyncio_wait_or_cancel, TaskObject, Condition
|
|
from artiq.protocols.sync_struct import Notifier
|
|
|
|
|
|
# Module-wide logger used by the scheduler components in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RunStatus(Enum):
    """Lifecycle states of a run as it moves through its pipeline.

    The ``name`` of the current member is published in the run's notifier
    entry under the ``"status"`` key.
    """

    # Initial state of every newly created run.
    pending = 0
    # The run requested a flush and waits for the other runs of its pool.
    flushing = 1
    # build() and prepare() are in progress.
    preparing = 2
    # Prepared; waiting to be picked up by the run stage.
    prepare_done = 3
    # run() or resume() is in progress.
    running = 4
    # run()/resume() reported completion; waiting for the analyze stage.
    run_done = 5
    # analyze() and write_results() are in progress.
    analyzing = 6
    # Queued for deletion.
    deleting = 7
    # run()/resume() returned without reporting completion.
    paused = 8
|
|
|
|
|
|
def _mk_worker_method(name):
|
|
async def worker_method(self, *args, **kwargs):
|
|
if self.worker.closed.is_set():
|
|
return True
|
|
m = getattr(self.worker, name)
|
|
try:
|
|
return await m(*args, **kwargs)
|
|
except Exception as e:
|
|
if isinstance(e, asyncio.CancelledError):
|
|
raise
|
|
if self.worker.closed.is_set():
|
|
logger.debug("suppressing worker exception of terminated run",
|
|
exc_info=True)
|
|
# Return completion on termination
|
|
return True
|
|
else:
|
|
raise
|
|
return worker_method
|
|
|
|
|
|
class Run:
    """One submitted run together with the worker that executes it.

    Instances are created by ``RunPool.submit``. The run's published state
    lives in the pool's notifier under ``rid`` until ``close`` removes it.
    """

    def __init__(self, rid, pipeline_name,
                 wd, expid, priority, due_date, flush,
                 pool, **kwargs):
        # called through pool
        self.rid = rid
        self.pipeline_name = pipeline_name
        self.wd = wd
        self.expid = expid
        self.priority = priority
        self.due_date = due_date
        self.flush = flush

        self.worker = Worker(pool.worker_handlers)
        self.termination_requested = False

        self._status = RunStatus.pending

        # Extra keyword arguments are merged after (and may override) the
        # standard fields of the published entry.
        self._notifier = pool.notifier
        self._notifier[self.rid] = {
            "pipeline": self.pipeline_name,
            "expid": self.expid,
            "priority": self.priority,
            "due_date": self.due_date,
            "flush": self.flush,
            "status": self._status.name,
            **kwargs,
        }
        self._state_changed = pool.state_changed

    @property
    def status(self):
        return self._status

    @status.setter
    def status(self, value):
        # The notifier entry is only updated while the worker is open,
        # because close() deletes that entry; the pool's state_changed
        # condition is notified in every case.
        self._status = value
        if not self.worker.closed.is_set():
            self._notifier[self.rid]["status"] = value.name
        self._state_changed.notify()

    # The run with the largest priority_key is to be scheduled first
    def priority_key(self, now=None):
        """Return the sort key ``(runnable, priority, due_date_k, -rid)``.

        ``due_date_k`` is ``-due_date``, or 0 when there is no due date.
        ``runnable`` is 1 unless both ``now`` and a due date are given, in
        which case it is ``int(now > due_date)``.
        """
        due = self.due_date
        if due is None:
            return (1, self.priority, 0, -self.rid)
        runnable = 1 if now is None else int(now > due)
        return (runnable, self.priority, -due, -self.rid)

    async def close(self):
        # called through pool
        await self.worker.close()
        del self._notifier[self.rid]

    _build = _mk_worker_method("build")

    async def build(self):
        await self._build(self.rid, self.pipeline_name, self.wd,
                          self.expid, self.priority)

    # Thin async wrappers around the worker methods of the same name.
    prepare = _mk_worker_method("prepare")
    run = _mk_worker_method("run")
    resume = _mk_worker_method("resume")
    analyze = _mk_worker_method("analyze")
    write_results = _mk_worker_method("write_results")
|
|
|
|
|
|
class RunPool:
    """The runs belonging to one pipeline, keyed by RID."""

    def __init__(self, ridc, worker_handlers, notifier, experiment_db):
        self.runs = {}
        self.state_changed = Condition()

        self.ridc = ridc
        self.worker_handlers = worker_handlers
        self.notifier = notifier
        self.experiment_db = experiment_db

    def submit(self, expid, priority, due_date, flush, pipeline_name):
        """Create a run, add it to the pool and return its new RID.

        If ``expid`` has a ``"repo_rev"`` key whose value is ``None``, it is
        replaced in place by the experiment database's current revision.
        When the key is present, the revision is requested from the
        repository backend, which supplies ``wd`` and ``repo_msg`` for the
        run. Called through the scheduler.
        """
        rid = self.ridc.get()
        wd = repo_msg = None
        if "repo_rev" in expid:
            if expid["repo_rev"] is None:
                expid["repo_rev"] = self.experiment_db.cur_rev
            backend = self.experiment_db.repo_backend
            wd, repo_msg = backend.request_rev(expid["repo_rev"])
        self.runs[rid] = Run(rid, pipeline_name, wd, expid, priority,
                             due_date, flush, self, repo_msg=repo_msg)
        self.state_changed.notify()
        return rid

    async def delete(self, rid):
        """Close run ``rid``, release its revision and drop it from the pool.

        Unknown RIDs are ignored. Called through the deleter.
        """
        try:
            run = self.runs[rid]
        except KeyError:
            return
        await run.close()
        if "repo_rev" in run.expid:
            self.experiment_db.repo_backend.release_rev(run.expid["repo_rev"])
        del self.runs[rid]
|
|
|
|
|
|
class PrepareStage(TaskObject):
    """Pipeline stage that builds and prepares pending runs.

    A pending run is picked only once it is due, and only if its priority
    key is strictly higher than that of the best run already prepared.
    Runs that request a flush first wait until every other run of the pool
    is pending or being deleted. Runs whose build or prepare step fails are
    passed to ``delete_cb``.
    """

    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        """If a run should get prepared now, return it.

        Otherwise, return a float representing the time before the best
        pending run becomes due, or None if there is no pending run or the
        best pending run does not outrank the best prepared run.
        """
        now = time()
        runs = self.pool.runs.values()

        candidate = max((r for r in runs if r.status == RunStatus.pending),
                        key=lambda r: r.priority_key(now), default=None)
        if candidate is None:
            return None

        top_prepared_run = max(
            (r for r in runs if r.status == RunStatus.prepare_done),
            key=lambda r: r.priority_key(), default=None)
        # prepare <candidate> (as well) only if it has higher priority than
        # the highest priority prepared run
        if (top_prepared_run is not None and
                top_prepared_run.priority_key() >= candidate.priority_key()):
            return None

        if candidate.due_date is None or candidate.due_date < now:
            return candidate
        return candidate.due_date - now

    async def _wait_for_flush(self, run):
        """Wait until all other runs of the pool are pending or deleting.

        Returns early if the worker of ``run`` gets closed in the meantime.
        """
        while not all(r.status in (RunStatus.pending, RunStatus.deleting)
                      or r is run
                      for r in self.pool.runs.values()):
            ev = [self.pool.state_changed.wait(),
                  run.worker.closed.wait()]
            await asyncio_wait_or_cancel(
                ev, return_when=asyncio.FIRST_COMPLETED)
            if run.worker.closed.is_set():
                break

    async def _do(self):
        while True:
            run = self._get_run()
            if run is None:
                await self.pool.state_changed.wait()
            elif isinstance(run, float):
                # Sleep until the next timed run is due, or until the pool
                # changes, whichever happens first.
                await asyncio_wait_or_cancel([self.pool.state_changed.wait()],
                                             timeout=run)
            else:
                if run.flush:
                    run.status = RunStatus.flushing
                    await self._wait_for_flush(run)
                    if run.worker.closed.is_set():
                        # The worker was closed (run deleted) while waiting.
                        continue
                run.status = RunStatus.preparing
                try:
                    await run.build()
                    await run.prepare()
                except asyncio.CancelledError:
                    # Stopping the stage must not be mistaken for a worker
                    # failure, nor be swallowed.
                    raise
                except Exception:
                    logger.error("got worker exception in prepare stage, "
                                 "deleting RID %d", run.rid)
                    log_worker_exception()
                    self.delete_cb(run.rid)
                else:
                    run.status = RunStatus.prepare_done
|
|
|
|
|
|
class RunStage(TaskObject):
    """Pipeline stage that executes prepared runs and handles pausing.

    Paused runs are kept on a stack. On each iteration the top of the
    stack is resumed, unless a prepared run with a strictly higher priority
    key exists (or the stack is empty), in which case that prepared run is
    started instead. Runs that fail are passed to ``delete_cb``.
    """

    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        """Return the highest-priority prepared run, or None if there is none."""
        return max((r for r in self.pool.runs.values()
                    if r.status == RunStatus.prepare_done),
                   key=lambda r: r.priority_key(), default=None)

    async def _do(self):
        stack = []

        while True:
            next_irun = self._get_run()
            if not stack or (
                    next_irun is not None and
                    next_irun.priority_key() > stack[-1].priority_key()):
                while next_irun is None:
                    await self.pool.state_changed.wait()
                    next_irun = self._get_run()
                stack.append(next_irun)

            run = stack.pop()
            try:
                if run.status == RunStatus.paused:
                    run.status = RunStatus.running
                    # clear "termination requested" flag now
                    # so that if it is set again during the resume, this
                    # results in another exception.
                    request_termination = run.termination_requested
                    run.termination_requested = False
                    completed = await run.resume(request_termination)
                else:
                    run.status = RunStatus.running
                    completed = await run.run()
            except asyncio.CancelledError:
                # Stopping the stage must not be mistaken for a worker
                # failure, nor be swallowed.
                raise
            except Exception:
                logger.error("got worker exception in run stage, "
                             "deleting RID %d", run.rid)
                log_worker_exception()
                self.delete_cb(run.rid)
            else:
                if completed:
                    run.status = RunStatus.run_done
                else:
                    run.status = RunStatus.paused
                    stack.append(run)
|
|
|
|
|
|
class AnalyzeStage(TaskObject):
    """Pipeline stage that analyzes finished runs and writes their results.

    Every processed run is then passed to ``delete_cb``, whether or not
    analysis succeeded.
    """

    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        """Return the highest-priority run in ``run_done``, or None."""
        return max((r for r in self.pool.runs.values()
                    if r.status == RunStatus.run_done),
                   key=lambda r: r.priority_key(), default=None)

    async def _do(self):
        while True:
            run = self._get_run()
            while run is None:
                await self.pool.state_changed.wait()
                run = self._get_run()
            run.status = RunStatus.analyzing
            try:
                await run.analyze()
                await run.write_results()
            except asyncio.CancelledError:
                # Stopping the stage must not be mistaken for a worker
                # failure, nor be swallowed.
                raise
            except Exception:
                logger.error("got worker exception in analyze stage, "
                             "deleting RID %d", run.rid)
                log_worker_exception()
            # The run is finished either way.
            self.delete_cb(run.rid)
|
|
|
|
|
|
class Pipeline:
    """A run pool plus its prepare, run and analyze stages."""

    def __init__(self, ridc, deleter, worker_handlers, notifier, experiment_db):
        self.pool = RunPool(ridc, worker_handlers, notifier, experiment_db)
        on_delete = deleter.delete
        self._prepare = PrepareStage(self.pool, on_delete)
        self._run = RunStage(self.pool, on_delete)
        self._analyze = AnalyzeStage(self.pool, on_delete)

    def start(self):
        for stage in (self._prepare, self._run, self._analyze):
            stage.start()

    async def stop(self):
        # NB: restart of a stopped pipeline is not supported
        # Stages are stopped in the reverse of their start order.
        for stage in (self._analyze, self._run, self._prepare):
            await stage.stop()
|
|
|
|
|
|
class Deleter(TaskObject):
    """Background task that deletes runs one at a time.

    ``delete`` queues requests, and ``_do`` processes them in order. After
    each request, pipelines whose pool has no runs left are stopped and
    removed from the ``pipelines`` dict passed to the constructor.
    """

    def __init__(self, pipelines):
        # dict of pipeline name -> Pipeline; entries are removed from it in
        # _gc_pipelines, so the caller's dict is mutated.
        self._pipelines = pipelines
        self._queue = asyncio.Queue()

    def delete(self, rid):
        """Mark run ``rid`` as deleting and queue it for removal.

        RIDs that no pipeline knows are still queued; ``_delete`` then
        ignores them.
        """
        logger.debug("delete request for RID %d", rid)
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                pipeline.pool.runs[rid].status = RunStatus.deleting
                break
        self._queue.put_nowait(rid)

    async def join(self):
        """Wait until every queued deletion request has been processed."""
        await self._queue.join()

    async def _delete(self, rid):
        # Remove the run from the first pipeline that tracks it, if any.
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                logger.debug("deleting RID %d...", rid)
                await pipeline.pool.delete(rid)
                logger.debug("deletion of RID %d completed", rid)
                break

    async def _gc_pipelines(self):
        # Iterate over a snapshot of the names, since entries are removed.
        pipeline_names = list(self._pipelines.keys())
        for name in pipeline_names:
            pipeline = self._pipelines.get(name)
            if pipeline is None or pipeline.pool.runs:
                continue
            logger.debug("garbage-collecting pipeline '%s'...", name)
            # Unregister before awaiting stop(): otherwise a submit arriving
            # while the stages shut down would add a run to this dying
            # pipeline, and that run would be lost when it is removed.
            del self._pipelines[name]
            await pipeline.stop()
            logger.debug("garbage-collection of pipeline '%s' completed",
                         name)

    async def _do(self):
        while True:
            rid = await self._queue.get()
            try:
                await self._delete(rid)
                await self._gc_pipelines()
            except asyncio.CancelledError:
                raise
            except Exception:
                # Keep serving later requests instead of letting the whole
                # deleter task die on this one.
                logger.error("failed to delete RID %d", rid, exc_info=True)
            finally:
                # Always acknowledge the request so join() cannot hang.
                self._queue.task_done()
|
|
|
|
|
|
class Scheduler:
    # Routes submitted runs to per-name pipelines. Pipelines are created on
    # demand by submit() and removed by the deleter once they have no runs.

    def __init__(self, ridc, worker_handlers, experiment_db):
        self.notifier = Notifier(dict())

        self._pipelines = {}
        self._worker_handlers = worker_handlers
        self._experiment_db = experiment_db
        self._terminated = False

        self._ridc = ridc
        self._deleter = Deleter(self._pipelines)

    def start(self):
        self._deleter.start()

    async def stop(self):
        # NB: restart of a stopped scheduler is not supported
        self._terminated = True  # prevent further runs from being created
        for pipeline in self._pipelines.values():
            for rid in pipeline.pool.runs:
                self._deleter.delete(rid)
        await self._deleter.join()
        await self._deleter.stop()
        if self._pipelines:
            logger.warning("some pipelines were not garbage-collected")

    def submit(self, pipeline_name, expid, priority, due_date, flush):
        """Submits a new run."""
        # mutates expid to insert head repository revision if None
        if self._terminated:
            return
        pipeline = self._pipelines.get(pipeline_name)
        if pipeline is None:
            logger.debug("creating pipeline '%s'", pipeline_name)
            pipeline = Pipeline(self._ridc, self._deleter,
                                self._worker_handlers, self.notifier,
                                self._experiment_db)
            self._pipelines[pipeline_name] = pipeline
            pipeline.start()
        return pipeline.pool.submit(expid, priority, due_date, flush,
                                    pipeline_name)

    def delete(self, rid):
        """Kills the run with the specified RID."""
        self._deleter.delete(rid)

    def request_termination(self, rid):
        """Requests graceful termination of the run with the specified RID."""
        owner = next((p for p in self._pipelines.values()
                      if rid in p.pool.runs), None)
        if owner is None:
            return
        run = owner.pool.runs[rid]
        if run.status in (RunStatus.running, RunStatus.paused):
            # Picked up by check_pause() and by the run stage on resume.
            run.termination_requested = True
        else:
            self.delete(rid)

    def get_status(self):
        """Returns a dictionary containing information about the runs currently
        tracked by the scheduler."""
        return self.notifier.read

    def check_pause(self, rid):
        """Returns ``True`` if there is a condition that could make ``pause``
        not return immediately (termination requested or higher priority run).

        The typical purpose of this function is to check from a kernel
        whether returning control to the host and pausing would have an effect,
        in order to avoid the cost of switching kernels in the common case
        where ``pause`` does nothing.
        """
        owner = next((p for p in self._pipelines.values()
                      if rid in p.pool.runs), None)
        if owner is None:
            raise KeyError("RID not found")
        runs = owner.pool.runs
        run = runs[rid]
        if run.status != RunStatus.running:
            return False
        if run.termination_requested:
            return True
        best_prepared = max((r for r in runs.values()
                             if r.status == RunStatus.prepare_done),
                            key=lambda r: r.priority_key(), default=None)
        if best_prepared is None:
            return False
        return best_prepared.priority_key() > run.priority_key()
|