scheduler: refactor, fix pipeline hazards

This commit is contained in:
Sebastien Bourdeauducq 2015-08-10 21:58:11 +08:00
parent 47e3d0337d
commit 06badd1dc1
3 changed files with 137 additions and 115 deletions

View File

@ -4,8 +4,7 @@ from enum import Enum
from time import time from time import time
from artiq.master.worker import Worker from artiq.master.worker import Worker
from artiq.tools import (asyncio_wait_or_cancel, asyncio_queue_peek, from artiq.tools import asyncio_wait_or_cancel, TaskObject, Condition
TaskObject, WaitSet)
from artiq.protocols.sync_struct import Notifier from artiq.protocols.sync_struct import Notifier
@ -20,7 +19,7 @@ class RunStatus(Enum):
running = 4 running = 4
run_done = 5 run_done = 5
analyzing = 6 analyzing = 6
analyze_done = 7 deleting = 7
paused = 8 paused = 8
@ -48,7 +47,7 @@ def _mk_worker_method(name):
class Run: class Run:
def __init__(self, rid, pipeline_name, def __init__(self, rid, pipeline_name,
wd, expid, priority, due_date, flush, wd, expid, priority, due_date, flush,
worker_handlers, notifier, **kwargs): pool, **kwargs):
# called through pool # called through pool
self.rid = rid self.rid = rid
self.pipeline_name = pipeline_name self.pipeline_name = pipeline_name
@ -58,7 +57,7 @@ class Run:
self.due_date = due_date self.due_date = due_date
self.flush = flush self.flush = flush
self.worker = Worker(worker_handlers) self.worker = Worker(pool.worker_handlers)
self._status = RunStatus.pending self._status = RunStatus.pending
@ -71,8 +70,9 @@ class Run:
"status": self._status.name "status": self._status.name
} }
notification.update(kwargs) notification.update(kwargs)
self._notifier = notifier self._notifier = pool.notifier
self._notifier[self.rid] = notification self._notifier[self.rid] = notification
self._state_changed = pool.state_changed
@property @property
def status(self): def status(self):
@ -83,6 +83,7 @@ class Run:
self._status = value self._status = value
if not self.worker.closed.is_set(): if not self.worker.closed.is_set():
self._notifier[self.rid]["status"] = self._status.name self._notifier[self.rid]["status"] = self._status.name
self._state_changed.notify()
# The run with the largest priority_key is to be scheduled first # The run with the largest priority_key is to be scheduled first
def priority_key(self, now=None): def priority_key(self, now=None):
@ -130,28 +131,27 @@ class RIDCounter:
class RunPool: class RunPool:
def __init__(self, ridc, worker_handlers, notifier, repo_backend): def __init__(self, ridc, worker_handlers, notifier, repo_backend):
self.runs = dict() self.runs = dict()
self.submitted_cb = None self.state_changed = Condition()
self._ridc = ridc self.ridc = ridc
self._worker_handlers = worker_handlers self.worker_handlers = worker_handlers
self._notifier = notifier self.notifier = notifier
self._repo_backend = repo_backend self.repo_backend = repo_backend
def submit(self, expid, priority, due_date, flush, pipeline_name): def submit(self, expid, priority, due_date, flush, pipeline_name):
# mutates expid to insert head repository revision if None # mutates expid to insert head repository revision if None.
# called through scheduler # called through scheduler.
rid = self._ridc.get() rid = self.ridc.get()
if "repo_rev" in expid: if "repo_rev" in expid:
if expid["repo_rev"] is None: if expid["repo_rev"] is None:
expid["repo_rev"] = self._repo_backend.get_head_rev() expid["repo_rev"] = self.repo_backend.get_head_rev()
wd, repo_msg = self._repo_backend.request_rev(expid["repo_rev"]) wd, repo_msg = self.repo_backend.request_rev(expid["repo_rev"])
else: else:
wd, repo_msg = None, None wd, repo_msg = None, None
run = Run(rid, pipeline_name, wd, expid, priority, due_date, flush, run = Run(rid, pipeline_name, wd, expid, priority, due_date, flush,
self._worker_handlers, self._notifier, repo_msg=repo_msg) self, repo_msg=repo_msg)
self.runs[rid] = run self.runs[rid] = run
if self.submitted_cb is not None: self.state_changed.notify()
self.submitted_cb()
return rid return rid
@asyncio.coroutine @asyncio.coroutine
@ -162,47 +162,72 @@ class RunPool:
run = self.runs[rid] run = self.runs[rid]
yield from run.close() yield from run.close()
if "repo_rev" in run.expid: if "repo_rev" in run.expid:
self._repo_backend.release_rev(run.expid["repo_rev"]) self.repo_backend.release_rev(run.expid["repo_rev"])
del self.runs[rid] del self.runs[rid]
class PrepareStage(TaskObject): class PrepareStage(TaskObject):
def __init__(self, flush_tracker, delete_cb, pool, outq): def __init__(self, pool, delete_cb):
self.flush_tracker = flush_tracker
self.delete_cb = delete_cb
self.pool = pool self.pool = pool
self.outq = outq self.delete_cb = delete_cb
self.pool_submitted = asyncio.Event() def _get_run(self):
self.pool.submitted_cb = lambda: self.pool_submitted.set() """If a run should get prepared now, return it.
Otherwise, return a float representing the time before the next timed
run becomes due, or None if there is no such run."""
now = time()
pending_runs = filter(lambda r: r.status == RunStatus.pending,
self.pool.runs.values())
try:
candidate = max(pending_runs, key=lambda r: r.priority_key(now))
except ValueError:
# pending_runs is an empty sequence
return None
prepared_runs = filter(lambda r: r.status == RunStatus.prepare_done,
self.pool.runs.values())
try:
top_prepared_run = max(prepared_runs,
key=lambda r: r.priority_key())
except ValueError:
# there are no existing prepared runs - go ahead with <candidate>
pass
else:
# prepare <candidate> (as well) only if it has higher priority than
# the highest priority prepared run
if top_prepared_run.priority_key() >= candidate.priority_key():
return None
if candidate.due_date is None or candidate.due_date < now:
return candidate
else:
return candidate.due_date - now
@asyncio.coroutine @asyncio.coroutine
def _push_runs(self): def _do(self):
"""Pushes all runs that have no due date of have a due date in the
past.
Returns the time before the next schedulable run, or None if the
pool is empty."""
while True: while True:
now = time() run = self._get_run()
pending_runs = filter(lambda r: r.status == RunStatus.pending, if run is None:
self.pool.runs.values()) yield from self.pool.state_changed.wait()
try: elif isinstance(run, float):
run = max(pending_runs, key=lambda r: r.priority_key(now)) yield from asyncio_wait_or_cancel([self.pool.state_changed.wait()],
except ValueError: timeout=run)
# pending_runs is an empty sequence else:
return None
if run.due_date is None or run.due_date < now:
if run.flush: if run.flush:
run.status = RunStatus.flushing run.status = RunStatus.flushing
yield from asyncio_wait_or_cancel( while not all(r.status in (RunStatus.pending,
[self.flush_tracker.wait_empty(), RunStatus.deleting)
run.worker.closed.wait()], or r is run
return_when=asyncio.FIRST_COMPLETED) for r in self.pool.runs.values()):
ev = [self.pool.state_changed.wait(),
run.worker.closed.wait()]
yield from asyncio_wait_or_cancel(
ev, return_when=asyncio.FIRST_COMPLETED)
if run.worker.closed.is_set():
break
if run.worker.closed.is_set(): if run.worker.closed.is_set():
continue continue
run.status = RunStatus.preparing run.status = RunStatus.preparing
self.flush_tracker.add(run.rid)
try: try:
yield from run.build() yield from run.build()
yield from run.prepare() yield from run.prepare()
@ -211,44 +236,38 @@ class PrepareStage(TaskObject):
"deleting RID %d", "deleting RID %d",
run.rid, exc_info=True) run.rid, exc_info=True)
self.delete_cb(run.rid) self.delete_cb(run.rid)
run.status = RunStatus.prepare_done else:
yield from self.outq.put(run) run.status = RunStatus.prepare_done
else:
return run.due_date - now
@asyncio.coroutine
def _do(self):
while True:
next_timed_in = yield from self._push_runs()
if next_timed_in is None:
# pool is empty - wait for something to be added to it
yield from self.pool_submitted.wait()
else:
# wait for next_timed_in seconds, or until the pool is modified
yield from asyncio_wait_or_cancel([self.pool_submitted.wait()],
timeout=next_timed_in)
self.pool_submitted.clear()
class RunStage(TaskObject): class RunStage(TaskObject):
def __init__(self, delete_cb, inq, outq): def __init__(self, pool, delete_cb):
self.pool = pool
self.delete_cb = delete_cb self.delete_cb = delete_cb
self.inq = inq
self.outq = outq def _get_run(self):
prepared_runs = filter(lambda r: r.status == RunStatus.prepare_done,
self.pool.runs.values())
try:
r = max(prepared_runs, key=lambda r: r.priority_key())
except ValueError:
# prepared_runs is an empty sequence
r = None
return r
@asyncio.coroutine @asyncio.coroutine
def _do(self): def _do(self):
stack = [] stack = []
while True: while True:
try: next_irun = self._get_run()
next_irun = asyncio_queue_peek(self.inq)
except asyncio.QueueEmpty:
next_irun = None
if not stack or ( if not stack or (
next_irun is not None and next_irun is not None and
next_irun.priority_key() > stack[-1].priority_key()): next_irun.priority_key() > stack[-1].priority_key()):
stack.append((yield from self.inq.get())) while next_irun is None:
yield from self.pool.state_changed.wait()
next_irun = self._get_run()
stack.append(next_irun)
run = stack.pop() run = stack.pop()
try: try:
@ -266,21 +285,33 @@ class RunStage(TaskObject):
else: else:
if completed: if completed:
run.status = RunStatus.run_done run.status = RunStatus.run_done
yield from self.outq.put(run)
else: else:
run.status = RunStatus.paused run.status = RunStatus.paused
stack.append(run) stack.append(run)
class AnalyzeStage(TaskObject): class AnalyzeStage(TaskObject):
def __init__(self, delete_cb, inq): def __init__(self, pool, delete_cb):
self.pool = pool
self.delete_cb = delete_cb self.delete_cb = delete_cb
self.inq = inq
def _get_run(self):
run_runs = filter(lambda r: r.status == RunStatus.run_done,
self.pool.runs.values())
try:
r = max(run_runs, key=lambda r: r.priority_key())
except ValueError:
# run_runs is an empty sequence
r = None
return r
@asyncio.coroutine @asyncio.coroutine
def _do(self): def _do(self):
while True: while True:
run = yield from self.inq.get() run = self._get_run()
while run is None:
yield from self.pool.state_changed.wait()
run = self._get_run()
run.status = RunStatus.analyzing run.status = RunStatus.analyzing
try: try:
yield from run.analyze() yield from run.analyze()
@ -290,22 +321,16 @@ class AnalyzeStage(TaskObject):
"deleting RID %d", "deleting RID %d",
run.rid, exc_info=True) run.rid, exc_info=True)
self.delete_cb(run.rid) self.delete_cb(run.rid)
run.status = RunStatus.analyze_done else:
self.delete_cb(run.rid) self.delete_cb(run.rid)
class Pipeline: class Pipeline:
def __init__(self, ridc, deleter, worker_handlers, notifier, repo_backend): def __init__(self, ridc, deleter, worker_handlers, notifier, repo_backend):
flush_tracker = WaitSet()
def delete_cb(rid):
deleter.delete(rid)
flush_tracker.discard(rid)
self.pool = RunPool(ridc, worker_handlers, notifier, repo_backend) self.pool = RunPool(ridc, worker_handlers, notifier, repo_backend)
self._prepare = PrepareStage(flush_tracker, delete_cb, self._prepare = PrepareStage(self.pool, deleter.delete)
self.pool, asyncio.Queue(maxsize=1)) self._run = RunStage(self.pool, deleter.delete)
self._run = RunStage(delete_cb, self._analyze = AnalyzeStage(self.pool, deleter.delete)
self._prepare.outq, asyncio.Queue(maxsize=1))
self._analyze = AnalyzeStage(delete_cb, self._run.outq)
def start(self): def start(self):
self._prepare.start() self._prepare.start()
@ -327,6 +352,10 @@ class Deleter(TaskObject):
def delete(self, rid): def delete(self, rid):
logger.debug("delete request for RID %d", rid) logger.debug("delete request for RID %d", rid)
for pipeline in self._pipelines.values():
if rid in pipeline.pool.runs:
pipeline.pool.runs[rid].status = RunStatus.deleting
break
self._queue.put_nowait(rid) self._queue.put_nowait(rid)
@asyncio.coroutine @asyncio.coroutine

View File

@ -50,7 +50,7 @@ def _get_basic_steps(rid, expid, priority=0, flush=False):
"path": [rid]}, "path": [rid]},
{"action": "setitem", "key": "status", "value": "analyzing", {"action": "setitem", "key": "status", "value": "analyzing",
"path": [rid]}, "path": [rid]},
{"action": "setitem", "key": "status", "value": "analyze_done", {"action": "setitem", "key": "status", "value": "deleting",
"path": [rid]}, "path": [rid]},
{"action": "delitem", "key": rid, "path": []} {"action": "delitem", "key": rid, "path": []}
] ]

View File

@ -5,6 +5,7 @@ import logging
import sys import sys
import asyncio import asyncio
import time import time
import collections
import os.path import os.path
from artiq.language.environment import is_experiment from artiq.language.environment import is_experiment
@ -125,14 +126,6 @@ def asyncio_wait_or_cancel(fs, **kwargs):
return fs return fs
def asyncio_queue_peek(q):
"""Like q.get_nowait(), but does not remove the item from the queue."""
if q._queue:
return q._queue[0]
else:
raise asyncio.QueueEmpty
class TaskObject: class TaskObject:
def start(self): def start(self):
self.task = asyncio.async(self._do()) self.task = asyncio.async(self._do())
@ -151,25 +144,25 @@ class TaskObject:
raise NotImplementedError raise NotImplementedError
class WaitSet: class Condition:
def __init__(self): def __init__(self, *, loop=None):
self._s = set() if loop is not None:
self._ev = asyncio.Event() self._loop = loop
def _update_ev(self):
if self._s:
self._ev.clear()
else: else:
self._ev.set() self._loop = asyncio.get_event_loop()
self._waiters = collections.deque()
def add(self, e):
self._s.add(e)
self._update_ev()
def discard(self, e):
self._s.discard(e)
self._update_ev()
@asyncio.coroutine @asyncio.coroutine
def wait_empty(self): def wait(self):
yield from self._ev.wait() """Wait until notified."""
fut = asyncio.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
finally:
self._waiters.remove(fut)
def notify(self):
for fut in self._waiters:
if not fut.done():
fut.set_result(False)