forked from M-Labs/artiq
sync_struct: Notifier.{read -> raw_view}, factor out common dict update code [nfc]
This commit is contained in:
parent
bd71852427
commit
c213ab13ba
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import tokenize
|
||||
|
||||
from artiq.protocols.sync_struct import Notifier, process_mod
|
||||
from artiq.protocols.sync_struct import Notifier, process_mod, update_from_dict
|
||||
from artiq.protocols import pyon
|
||||
from artiq.tools import TaskObject
|
||||
|
||||
@ -19,20 +19,14 @@ class DeviceDB:
|
||||
self.data = Notifier(device_db_from_file(self.backing_file))
|
||||
|
||||
def scan(self):
|
||||
new_data = device_db_from_file(self.backing_file)
|
||||
|
||||
for k in list(self.data.read.keys()):
|
||||
if k not in new_data:
|
||||
del self.data[k]
|
||||
for k in new_data.keys():
|
||||
if k not in self.data.read or self.data.read[k] != new_data[k]:
|
||||
self.data[k] = new_data[k]
|
||||
update_from_dict(self.data,
|
||||
device_db_from_file(self.backing_file))
|
||||
|
||||
def get_device_db(self):
|
||||
return self.data.read
|
||||
return self.data.raw_view
|
||||
|
||||
def get(self, key):
|
||||
return self.data.read[key]
|
||||
return self.data.raw_view[key]
|
||||
|
||||
|
||||
class DatasetDB(TaskObject):
|
||||
@ -47,7 +41,7 @@ class DatasetDB(TaskObject):
|
||||
self.data = Notifier({k: (True, v) for k, v in file_data.items()})
|
||||
|
||||
def save(self):
|
||||
data = {k: v[1] for k, v in self.data.read.items() if v[0]}
|
||||
data = {k: v[1] for k, v in self.data.raw_view.items() if v[0]}
|
||||
pyon.store_file(self.persist_file, data)
|
||||
|
||||
async def _do(self):
|
||||
@ -59,7 +53,7 @@ class DatasetDB(TaskObject):
|
||||
self.save()
|
||||
|
||||
def get(self, key):
|
||||
return self.data.read[key][1]
|
||||
return self.data.raw_view[key][1]
|
||||
|
||||
def update(self, mod):
|
||||
process_mod(self.data, mod)
|
||||
@ -67,8 +61,8 @@ class DatasetDB(TaskObject):
|
||||
# convenience functions (update() can be used instead)
|
||||
def set(self, key, value, persist=None):
|
||||
if persist is None:
|
||||
if key in self.data.read:
|
||||
persist = self.data.read[key][0]
|
||||
if key in self.data.raw_view:
|
||||
persist = self.data.raw_view[key][0]
|
||||
else:
|
||||
persist = False
|
||||
self.data[key] = (persist, value)
|
||||
|
@ -5,7 +5,7 @@ import shutil
|
||||
import time
|
||||
import logging
|
||||
|
||||
from artiq.protocols.sync_struct import Notifier
|
||||
from artiq.protocols.sync_struct import Notifier, update_from_dict
|
||||
from artiq.master.worker import (Worker, WorkerInternalException,
|
||||
log_worker_exception)
|
||||
from artiq.tools import get_windows_drives, exc_to_warning
|
||||
@ -81,15 +81,6 @@ class _RepoScanner:
|
||||
return r
|
||||
|
||||
|
||||
def _sync_explist(target, source):
|
||||
for k in list(target.read.keys()):
|
||||
if k not in source:
|
||||
del target[k]
|
||||
for k in source.keys():
|
||||
if k not in target.read or target.read[k] != source[k]:
|
||||
target[k] = source[k]
|
||||
|
||||
|
||||
class ExperimentDB:
|
||||
def __init__(self, repo_backend, worker_handlers):
|
||||
self.repo_backend = repo_backend
|
||||
@ -125,7 +116,7 @@ class ExperimentDB:
|
||||
new_explist = await _RepoScanner(self.worker_handlers).scan(wd)
|
||||
logger.info("repository scan took %d seconds", time.monotonic()-t1)
|
||||
|
||||
_sync_explist(self.explist, new_explist)
|
||||
update_from_dict(self.explist, new_explist)
|
||||
finally:
|
||||
self._scanning = False
|
||||
self.status["scanning"] = False
|
||||
|
@ -442,8 +442,10 @@ class Scheduler:
|
||||
|
||||
def get_status(self):
|
||||
"""Returns a dictionary containing information about the runs currently
|
||||
tracked by the scheduler."""
|
||||
return self.notifier.read
|
||||
tracked by the scheduler.
|
||||
|
||||
Must not be modified."""
|
||||
return self.notifier.raw_view
|
||||
|
||||
def check_pause(self, rid):
|
||||
"""Returns ``True`` if there is a condition that could make ``pause``
|
||||
|
@ -110,12 +110,12 @@ class DeviceManager:
|
||||
|
||||
class DatasetManager:
|
||||
def __init__(self, ddb):
|
||||
self.broadcast = Notifier(dict())
|
||||
self._broadcaster = Notifier(dict())
|
||||
self.local = dict()
|
||||
self.archive = dict()
|
||||
|
||||
self.ddb = ddb
|
||||
self.broadcast.publish = ddb.update
|
||||
self._broadcaster.publish = ddb.update
|
||||
|
||||
def set(self, key, value, broadcast=False, persist=False, archive=True):
|
||||
if key in self.archive:
|
||||
@ -125,10 +125,12 @@ class DatasetManager:
|
||||
|
||||
if persist:
|
||||
broadcast = True
|
||||
|
||||
if broadcast:
|
||||
self.broadcast[key] = persist, value
|
||||
elif key in self.broadcast.read:
|
||||
del self.broadcast[key]
|
||||
self._broadcaster[key] = persist, value
|
||||
elif key in self._broadcaster.raw_view:
|
||||
del self._broadcaster[key]
|
||||
|
||||
if archive:
|
||||
self.local[key] = value
|
||||
elif key in self.local:
|
||||
@ -138,10 +140,10 @@ class DatasetManager:
|
||||
target = None
|
||||
if key in self.local:
|
||||
target = self.local[key]
|
||||
if key in self.broadcast.read:
|
||||
if key in self._broadcaster.raw_view:
|
||||
if target is not None:
|
||||
assert target is self.broadcast.read[key][1]
|
||||
target = self.broadcast[key][1]
|
||||
assert target is self._broadcaster.raw_view[key][1]
|
||||
target = self._broadcaster[key][1]
|
||||
if target is None:
|
||||
raise KeyError("Cannot mutate non-existing dataset")
|
||||
|
||||
@ -155,19 +157,20 @@ class DatasetManager:
|
||||
def get(self, key, archive=False):
|
||||
if key in self.local:
|
||||
return self.local[key]
|
||||
else:
|
||||
data = self.ddb.get(key)
|
||||
if archive:
|
||||
if key in self.archive:
|
||||
logger.warning("Dataset '%s' is already in archive, "
|
||||
"overwriting", key, stack_info=True)
|
||||
self.archive[key] = data
|
||||
return data
|
||||
|
||||
data = self.ddb.get(key)
|
||||
if archive:
|
||||
if key in self.archive:
|
||||
logger.warning("Dataset '%s' is already in archive, "
|
||||
"overwriting", key, stack_info=True)
|
||||
self.archive[key] = data
|
||||
return data
|
||||
|
||||
def write_hdf5(self, f):
|
||||
datasets_group = f.create_group("datasets")
|
||||
for k, v in self.local.items():
|
||||
datasets_group[k] = v
|
||||
|
||||
archive_group = f.create_group("archive")
|
||||
for k, v in self.archive.items():
|
||||
archive_group[k] = v
|
||||
|
@ -127,18 +127,19 @@ class Notifier:
|
||||
>>> n = Notifier([])
|
||||
>>> n.append([])
|
||||
>>> n[0].append(42)
|
||||
>>> n.read
|
||||
>>> n.raw_view
|
||||
[[42]]
|
||||
|
||||
This class does not perform any network I/O and is meant to be used with
|
||||
e.g. the :class:`.Publisher` for this purpose. Only one publisher at most can be
|
||||
associated with a :class:`.Notifier`.
|
||||
|
||||
:param backing_struct: Structure to encapsulate. For convenience, it
|
||||
also becomes available as the ``read`` property of the :class:`.Notifier`.
|
||||
:param backing_struct: Structure to encapsulate.
|
||||
"""
|
||||
def __init__(self, backing_struct, root=None, path=[]):
|
||||
self.read = backing_struct
|
||||
#: The raw data encapsulated (read-only!).
|
||||
self.raw_view = backing_struct
|
||||
|
||||
if root is None:
|
||||
self.root = self
|
||||
self.publish = None
|
||||
@ -197,6 +198,27 @@ class Notifier:
|
||||
return Notifier(item, self.root, self._path + [key])
|
||||
|
||||
|
||||
def update_from_dict(target, source):
|
||||
"""Updates notifier contents from given source dictionary.
|
||||
|
||||
Only the necessary changes are performed; unchanged fields are not written.
|
||||
(Currently, modifications are only performed at the top level. That is,
|
||||
whenever there is a change to a child array/struct the entire member is
|
||||
updated instead of choosing a more optimal set of mods.)
|
||||
"""
|
||||
curr = target.raw_view
|
||||
|
||||
# Delete removed keys.
|
||||
for k in list(curr.keys()):
|
||||
if k not in source:
|
||||
del target[k]
|
||||
|
||||
# Insert/update changed data.
|
||||
for k in source.keys():
|
||||
if k not in curr or curr[k] != source[k]:
|
||||
target[k] = source[k]
|
||||
|
||||
|
||||
class Publisher(AsyncioServer):
|
||||
"""A network server that publish changes to structures encapsulated in
|
||||
a :class:`.Notifier`.
|
||||
@ -230,7 +252,7 @@ class Publisher(AsyncioServer):
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
obj = {"action": "init", "struct": notifier.read}
|
||||
obj = {"action": "init", "struct": notifier.raw_view}
|
||||
line = pyon.encode(obj) + "\n"
|
||||
writer.write(line.encode())
|
||||
|
||||
|
@ -63,7 +63,7 @@ class SyncStructCase(unittest.TestCase):
|
||||
await subscriber.close()
|
||||
await publisher.stop()
|
||||
|
||||
self.assertEqual(self.received_dict, test_dict.read)
|
||||
self.assertEqual(self.received_dict, test_dict.raw_view)
|
||||
|
||||
def test_recv(self):
|
||||
self.loop.run_until_complete(self._do_test_recv())
|
||||
|
Loading…
Reference in New Issue
Block a user