forked from M-Labs/artiq
sync_struct: Notifier.{read -> raw_view}, factor out common dict update code [nfc]
This commit is contained in:
parent
bd71852427
commit
c213ab13ba
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import tokenize
|
import tokenize
|
||||||
|
|
||||||
from artiq.protocols.sync_struct import Notifier, process_mod
|
from artiq.protocols.sync_struct import Notifier, process_mod, update_from_dict
|
||||||
from artiq.protocols import pyon
|
from artiq.protocols import pyon
|
||||||
from artiq.tools import TaskObject
|
from artiq.tools import TaskObject
|
||||||
|
|
||||||
|
@ -19,20 +19,14 @@ class DeviceDB:
|
||||||
self.data = Notifier(device_db_from_file(self.backing_file))
|
self.data = Notifier(device_db_from_file(self.backing_file))
|
||||||
|
|
||||||
def scan(self):
|
def scan(self):
|
||||||
new_data = device_db_from_file(self.backing_file)
|
update_from_dict(self.data,
|
||||||
|
device_db_from_file(self.backing_file))
|
||||||
for k in list(self.data.read.keys()):
|
|
||||||
if k not in new_data:
|
|
||||||
del self.data[k]
|
|
||||||
for k in new_data.keys():
|
|
||||||
if k not in self.data.read or self.data.read[k] != new_data[k]:
|
|
||||||
self.data[k] = new_data[k]
|
|
||||||
|
|
||||||
def get_device_db(self):
|
def get_device_db(self):
|
||||||
return self.data.read
|
return self.data.raw_view
|
||||||
|
|
||||||
def get(self, key):
|
def get(self, key):
|
||||||
return self.data.read[key]
|
return self.data.raw_view[key]
|
||||||
|
|
||||||
|
|
||||||
class DatasetDB(TaskObject):
|
class DatasetDB(TaskObject):
|
||||||
|
@ -47,7 +41,7 @@ class DatasetDB(TaskObject):
|
||||||
self.data = Notifier({k: (True, v) for k, v in file_data.items()})
|
self.data = Notifier({k: (True, v) for k, v in file_data.items()})
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
data = {k: v[1] for k, v in self.data.read.items() if v[0]}
|
data = {k: v[1] for k, v in self.data.raw_view.items() if v[0]}
|
||||||
pyon.store_file(self.persist_file, data)
|
pyon.store_file(self.persist_file, data)
|
||||||
|
|
||||||
async def _do(self):
|
async def _do(self):
|
||||||
|
@ -59,7 +53,7 @@ class DatasetDB(TaskObject):
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def get(self, key):
|
def get(self, key):
|
||||||
return self.data.read[key][1]
|
return self.data.raw_view[key][1]
|
||||||
|
|
||||||
def update(self, mod):
|
def update(self, mod):
|
||||||
process_mod(self.data, mod)
|
process_mod(self.data, mod)
|
||||||
|
@ -67,8 +61,8 @@ class DatasetDB(TaskObject):
|
||||||
# convenience functions (update() can be used instead)
|
# convenience functions (update() can be used instead)
|
||||||
def set(self, key, value, persist=None):
|
def set(self, key, value, persist=None):
|
||||||
if persist is None:
|
if persist is None:
|
||||||
if key in self.data.read:
|
if key in self.data.raw_view:
|
||||||
persist = self.data.read[key][0]
|
persist = self.data.raw_view[key][0]
|
||||||
else:
|
else:
|
||||||
persist = False
|
persist = False
|
||||||
self.data[key] = (persist, value)
|
self.data[key] = (persist, value)
|
||||||
|
|
|
@ -5,7 +5,7 @@ import shutil
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from artiq.protocols.sync_struct import Notifier
|
from artiq.protocols.sync_struct import Notifier, update_from_dict
|
||||||
from artiq.master.worker import (Worker, WorkerInternalException,
|
from artiq.master.worker import (Worker, WorkerInternalException,
|
||||||
log_worker_exception)
|
log_worker_exception)
|
||||||
from artiq.tools import get_windows_drives, exc_to_warning
|
from artiq.tools import get_windows_drives, exc_to_warning
|
||||||
|
@ -81,15 +81,6 @@ class _RepoScanner:
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def _sync_explist(target, source):
|
|
||||||
for k in list(target.read.keys()):
|
|
||||||
if k not in source:
|
|
||||||
del target[k]
|
|
||||||
for k in source.keys():
|
|
||||||
if k not in target.read or target.read[k] != source[k]:
|
|
||||||
target[k] = source[k]
|
|
||||||
|
|
||||||
|
|
||||||
class ExperimentDB:
|
class ExperimentDB:
|
||||||
def __init__(self, repo_backend, worker_handlers):
|
def __init__(self, repo_backend, worker_handlers):
|
||||||
self.repo_backend = repo_backend
|
self.repo_backend = repo_backend
|
||||||
|
@ -125,7 +116,7 @@ class ExperimentDB:
|
||||||
new_explist = await _RepoScanner(self.worker_handlers).scan(wd)
|
new_explist = await _RepoScanner(self.worker_handlers).scan(wd)
|
||||||
logger.info("repository scan took %d seconds", time.monotonic()-t1)
|
logger.info("repository scan took %d seconds", time.monotonic()-t1)
|
||||||
|
|
||||||
_sync_explist(self.explist, new_explist)
|
update_from_dict(self.explist, new_explist)
|
||||||
finally:
|
finally:
|
||||||
self._scanning = False
|
self._scanning = False
|
||||||
self.status["scanning"] = False
|
self.status["scanning"] = False
|
||||||
|
|
|
@ -442,8 +442,10 @@ class Scheduler:
|
||||||
|
|
||||||
def get_status(self):
|
def get_status(self):
|
||||||
"""Returns a dictionary containing information about the runs currently
|
"""Returns a dictionary containing information about the runs currently
|
||||||
tracked by the scheduler."""
|
tracked by the scheduler.
|
||||||
return self.notifier.read
|
|
||||||
|
Must not be modified."""
|
||||||
|
return self.notifier.raw_view
|
||||||
|
|
||||||
def check_pause(self, rid):
|
def check_pause(self, rid):
|
||||||
"""Returns ``True`` if there is a condition that could make ``pause``
|
"""Returns ``True`` if there is a condition that could make ``pause``
|
||||||
|
|
|
@ -110,12 +110,12 @@ class DeviceManager:
|
||||||
|
|
||||||
class DatasetManager:
|
class DatasetManager:
|
||||||
def __init__(self, ddb):
|
def __init__(self, ddb):
|
||||||
self.broadcast = Notifier(dict())
|
self._broadcaster = Notifier(dict())
|
||||||
self.local = dict()
|
self.local = dict()
|
||||||
self.archive = dict()
|
self.archive = dict()
|
||||||
|
|
||||||
self.ddb = ddb
|
self.ddb = ddb
|
||||||
self.broadcast.publish = ddb.update
|
self._broadcaster.publish = ddb.update
|
||||||
|
|
||||||
def set(self, key, value, broadcast=False, persist=False, archive=True):
|
def set(self, key, value, broadcast=False, persist=False, archive=True):
|
||||||
if key in self.archive:
|
if key in self.archive:
|
||||||
|
@ -125,10 +125,12 @@ class DatasetManager:
|
||||||
|
|
||||||
if persist:
|
if persist:
|
||||||
broadcast = True
|
broadcast = True
|
||||||
|
|
||||||
if broadcast:
|
if broadcast:
|
||||||
self.broadcast[key] = persist, value
|
self._broadcaster[key] = persist, value
|
||||||
elif key in self.broadcast.read:
|
elif key in self._broadcaster.raw_view:
|
||||||
del self.broadcast[key]
|
del self._broadcaster[key]
|
||||||
|
|
||||||
if archive:
|
if archive:
|
||||||
self.local[key] = value
|
self.local[key] = value
|
||||||
elif key in self.local:
|
elif key in self.local:
|
||||||
|
@ -138,10 +140,10 @@ class DatasetManager:
|
||||||
target = None
|
target = None
|
||||||
if key in self.local:
|
if key in self.local:
|
||||||
target = self.local[key]
|
target = self.local[key]
|
||||||
if key in self.broadcast.read:
|
if key in self._broadcaster.raw_view:
|
||||||
if target is not None:
|
if target is not None:
|
||||||
assert target is self.broadcast.read[key][1]
|
assert target is self._broadcaster.raw_view[key][1]
|
||||||
target = self.broadcast[key][1]
|
target = self._broadcaster[key][1]
|
||||||
if target is None:
|
if target is None:
|
||||||
raise KeyError("Cannot mutate non-existing dataset")
|
raise KeyError("Cannot mutate non-existing dataset")
|
||||||
|
|
||||||
|
@ -155,19 +157,20 @@ class DatasetManager:
|
||||||
def get(self, key, archive=False):
|
def get(self, key, archive=False):
|
||||||
if key in self.local:
|
if key in self.local:
|
||||||
return self.local[key]
|
return self.local[key]
|
||||||
else:
|
|
||||||
data = self.ddb.get(key)
|
data = self.ddb.get(key)
|
||||||
if archive:
|
if archive:
|
||||||
if key in self.archive:
|
if key in self.archive:
|
||||||
logger.warning("Dataset '%s' is already in archive, "
|
logger.warning("Dataset '%s' is already in archive, "
|
||||||
"overwriting", key, stack_info=True)
|
"overwriting", key, stack_info=True)
|
||||||
self.archive[key] = data
|
self.archive[key] = data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def write_hdf5(self, f):
|
def write_hdf5(self, f):
|
||||||
datasets_group = f.create_group("datasets")
|
datasets_group = f.create_group("datasets")
|
||||||
for k, v in self.local.items():
|
for k, v in self.local.items():
|
||||||
datasets_group[k] = v
|
datasets_group[k] = v
|
||||||
|
|
||||||
archive_group = f.create_group("archive")
|
archive_group = f.create_group("archive")
|
||||||
for k, v in self.archive.items():
|
for k, v in self.archive.items():
|
||||||
archive_group[k] = v
|
archive_group[k] = v
|
||||||
|
|
|
@ -127,18 +127,19 @@ class Notifier:
|
||||||
>>> n = Notifier([])
|
>>> n = Notifier([])
|
||||||
>>> n.append([])
|
>>> n.append([])
|
||||||
>>> n[0].append(42)
|
>>> n[0].append(42)
|
||||||
>>> n.read
|
>>> n.raw_view
|
||||||
[[42]]
|
[[42]]
|
||||||
|
|
||||||
This class does not perform any network I/O and is meant to be used with
|
This class does not perform any network I/O and is meant to be used with
|
||||||
e.g. the :class:`.Publisher` for this purpose. Only one publisher at most can be
|
e.g. the :class:`.Publisher` for this purpose. Only one publisher at most can be
|
||||||
associated with a :class:`.Notifier`.
|
associated with a :class:`.Notifier`.
|
||||||
|
|
||||||
:param backing_struct: Structure to encapsulate. For convenience, it
|
:param backing_struct: Structure to encapsulate.
|
||||||
also becomes available as the ``read`` property of the :class:`.Notifier`.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, backing_struct, root=None, path=[]):
|
def __init__(self, backing_struct, root=None, path=[]):
|
||||||
self.read = backing_struct
|
#: The raw data encapsulated (read-only!).
|
||||||
|
self.raw_view = backing_struct
|
||||||
|
|
||||||
if root is None:
|
if root is None:
|
||||||
self.root = self
|
self.root = self
|
||||||
self.publish = None
|
self.publish = None
|
||||||
|
@ -197,6 +198,27 @@ class Notifier:
|
||||||
return Notifier(item, self.root, self._path + [key])
|
return Notifier(item, self.root, self._path + [key])
|
||||||
|
|
||||||
|
|
||||||
|
def update_from_dict(target, source):
|
||||||
|
"""Updates notifier contents from given source dictionary.
|
||||||
|
|
||||||
|
Only the necessary changes are performed; unchanged fields are not written.
|
||||||
|
(Currently, modifications are only performed at the top level. That is,
|
||||||
|
whenever there is a change to a child array/struct the entire member is
|
||||||
|
updated instead of choosing a more optimal set of mods.)
|
||||||
|
"""
|
||||||
|
curr = target.raw_view
|
||||||
|
|
||||||
|
# Delete removed keys.
|
||||||
|
for k in list(curr.keys()):
|
||||||
|
if k not in source:
|
||||||
|
del target[k]
|
||||||
|
|
||||||
|
# Insert/update changed data.
|
||||||
|
for k in source.keys():
|
||||||
|
if k not in curr or curr[k] != source[k]:
|
||||||
|
target[k] = source[k]
|
||||||
|
|
||||||
|
|
||||||
class Publisher(AsyncioServer):
|
class Publisher(AsyncioServer):
|
||||||
"""A network server that publish changes to structures encapsulated in
|
"""A network server that publish changes to structures encapsulated in
|
||||||
a :class:`.Notifier`.
|
a :class:`.Notifier`.
|
||||||
|
@ -230,7 +252,7 @@ class Publisher(AsyncioServer):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return
|
return
|
||||||
|
|
||||||
obj = {"action": "init", "struct": notifier.read}
|
obj = {"action": "init", "struct": notifier.raw_view}
|
||||||
line = pyon.encode(obj) + "\n"
|
line = pyon.encode(obj) + "\n"
|
||||||
writer.write(line.encode())
|
writer.write(line.encode())
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ class SyncStructCase(unittest.TestCase):
|
||||||
await subscriber.close()
|
await subscriber.close()
|
||||||
await publisher.stop()
|
await publisher.stop()
|
||||||
|
|
||||||
self.assertEqual(self.received_dict, test_dict.read)
|
self.assertEqual(self.received_dict, test_dict.raw_view)
|
||||||
|
|
||||||
def test_recv(self):
|
def test_recv(self):
|
||||||
self.loop.run_until_complete(self._do_test_recv())
|
self.loop.run_until_complete(self._do_test_recv())
|
||||||
|
|
Loading…
Reference in New Issue