diff --git a/artiq/frontend/artiq_master.py b/artiq/frontend/artiq_master.py index 1d3a8617e..a89ef34f6 100755 --- a/artiq/frontend/artiq_master.py +++ b/artiq/frontend/artiq_master.py @@ -63,30 +63,30 @@ def main(): dataset_db = DatasetDB(args.dataset_db) dataset_db.start() atexit_register_coroutine(dataset_db.stop) + worker_handlers = dict() if args.git: repo_backend = GitBackend(args.repository) else: repo_backend = FilesystemBackend(args.repository) - experiment_db = ExperimentDB(repo_backend, device_db.get_device_db) + experiment_db = ExperimentDB(repo_backend, worker_handlers) atexit.register(experiment_db.close) - experiment_db.scan_repository_async() - worker_handlers = { + scheduler = Scheduler(RIDCounter(), worker_handlers, experiment_db) + scheduler.start() + atexit_register_coroutine(scheduler.stop) + + worker_handlers.update({ "get_device_db": device_db.get_device_db, "get_device": device_db.get, "get_dataset": dataset_db.get, - "update_dataset": dataset_db.update - } - scheduler = Scheduler(RIDCounter(), worker_handlers, experiment_db) - worker_handlers.update({ + "update_dataset": dataset_db.update, "scheduler_submit": scheduler.submit, "scheduler_delete": scheduler.delete, "scheduler_request_termination": scheduler.request_termination, - "scheduler_get_status": scheduler.get_status + "scheduler_get_status": scheduler.get_status }) - scheduler.start() - atexit_register_coroutine(scheduler.stop) + experiment_db.scan_repository_async() bind = bind_address_from_args(args) diff --git a/artiq/master/experiments.py b/artiq/master/experiments.py index 866351f43..904d0c804 100644 --- a/artiq/master/experiments.py +++ b/artiq/master/experiments.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) async def _get_repository_entries(entry_dict, - root, filename, get_device_db): - worker = Worker({"get_device_db": get_device_db}) + root, filename, worker_handlers): + worker = Worker(worker_handlers) try: description = await worker.examine("scan", os.path.join(root, filename)) except: @@ -45,7 +45,7 @@ async def _get_repository_entries(entry_dict, entry_dict[name] = entry -async def _scan_experiments(root, get_device_db, subdir=""): +async def _scan_experiments(root, worker_handlers, subdir=""): entry_dict = dict() for de in os.scandir(os.path.join(root, subdir)): if de.name.startswith("."): @@ -54,13 +54,13 @@ async def _scan_experiments(root, get_device_db, subdir=""): filename = os.path.join(subdir, de.name) try: await _get_repository_entries( - entry_dict, root, filename, get_device_db) + entry_dict, root, filename, worker_handlers) except Exception as exc: logger.warning("Skipping file '%s'", filename, exc_info=not isinstance(exc, WorkerInternalException)) if de.is_dir(): subentries = await _scan_experiments( - root, get_device_db, + root, worker_handlers, os.path.join(subdir, de.name)) entries = {de.name + "/" + k: v for k, v in subentries.items()} entry_dict.update(entries) @@ -77,9 +77,9 @@ def _sync_explist(target, source): class ExperimentDB: - def __init__(self, repo_backend, get_device_db_fn): + def __init__(self, repo_backend, worker_handlers): self.repo_backend = repo_backend - self.get_device_db_fn = get_device_db_fn + self.worker_handlers = worker_handlers self.cur_rev = self.repo_backend.get_head_rev() self.repo_backend.request_rev(self.cur_rev) @@ -107,7 +107,7 @@ class ExperimentDB: self.repo_backend.release_rev(self.cur_rev) self.cur_rev = new_cur_rev self.status["cur_rev"] = new_cur_rev - new_explist = await _scan_experiments(wd, self.get_device_db_fn) + new_explist = await _scan_experiments(wd, self.worker_handlers) _sync_explist(self.explist, new_explist) finally: @@ -123,7 +123,7 @@ class ExperimentDB: revision = self.cur_rev wd, _ = self.repo_backend.request_rev(revision) filename = os.path.join(wd, filename) - worker = Worker({"get_device_db": self.get_device_db_fn}) + worker = Worker(self.worker_handlers) try: description = await worker.examine("examine", filename) finally: