environment: refactor

This commit is contained in:
Sebastien Bourdeauducq 2016-04-16 19:31:07 +08:00
parent 12a8c76df9
commit caf774579a
11 changed files with 97 additions and 115 deletions

View File

@ -36,6 +36,8 @@ unreleased [2.x]
(i.e. grouping by day and then by hour, instead of by day and then by minute) (i.e. grouping by day and then by hour, instead of by day and then by minute)
* GUI tools save their state file in the user's home directory instead of the * GUI tools save their state file in the user's home directory instead of the
current directory. current directory.
* The ``parent`` keyword argument of ``HasEnvironment`` (and ``EnvExperiment``)
has been replaced. Pass the parent as first argument instead.
unreleased [1.0rc3] unreleased [1.0rc3]

View File

@ -54,8 +54,8 @@ class ArgumentsDemo(EnvExperiment):
self.setattr_argument("enum", EnumerationValue( self.setattr_argument("enum", EnumerationValue(
["foo", "bar", "quux"], "foo"), "Group") ["foo", "bar", "quux"], "foo"), "Group")
self.sc1 = SubComponent1(parent=self) self.sc1 = SubComponent1(self)
self.sc2 = SubComponent2(parent=self) self.sc2 = SubComponent2(self)
def run(self): def run(self):
logging.error("logging test: error") logging.error("logging test: error")

View File

@ -17,8 +17,8 @@ class Transport(EnvExperiment):
self.setattr_argument("wait_at_stop", NumberValue(100*us)) self.setattr_argument("wait_at_stop", NumberValue(100*us))
self.setattr_argument("speed", NumberValue(1/(10*us))) self.setattr_argument("speed", NumberValue(1/(10*us)))
self.repeats = int(self.get_argument("repeats", NumberValue(100))) self.setattr_argument("repeats", NumberValue(100, step=1, ndecimals=0))
self.bins = int(self.get_argument("bins", NumberValue(100))) self.setattr_argument("bins", NumberValue(100, step=1, ndecimals=0))
t = np.linspace(0, 10, 101) # waveform time t = np.linspace(0, 10, 101) # waveform time
u = 1 - np.cos(np.pi*t/t[-1]) u = 1 - np.cos(np.pi*t/t[-1])

View File

@ -7,9 +7,6 @@ from artiq.experiment import *
class Histograms(EnvExperiment): class Histograms(EnvExperiment):
"""Histograms demo""" """Histograms demo"""
def build(self):
pass
def run(self): def run(self):
nbins = 50 nbins = 50
npoints = 20 npoints = 20

View File

@ -102,7 +102,7 @@ class SpeedBenchmark(EnvExperiment):
self.scheduler.priority, None, False) self.scheduler.priority, None, False)
def run_without_scheduler(self, pause): def run_without_scheduler(self, pause):
payload = globals()["_Payload" + self.payload](*self.managers()) payload = globals()["_Payload" + self.payload](self)
start_time = time.monotonic() start_time = time.monotonic()
for i in range(int(self.nruns)): for i in range(int(self.nruns)):

View File

@ -12,7 +12,7 @@ import h5py
from llvmlite_artiq import binding as llvm from llvmlite_artiq import binding as llvm
from artiq.language.environment import EnvExperiment from artiq.language.environment import EnvExperiment, ProcessArgumentManager
from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.databases import DeviceDB, DatasetDB
from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.master.worker_db import DeviceManager, DatasetManager
from artiq.coredevice.core import CompileError, host_only from artiq.coredevice.core import CompileError, host_only
@ -167,7 +167,8 @@ def _build_experiment(device_mgr, dataset_mgr, args):
"arguments": arguments "arguments": arguments
} }
device_mgr.virtual_devices["scheduler"].expid = expid device_mgr.virtual_devices["scheduler"].expid = expid
return exp(device_mgr, dataset_mgr, **arguments) argument_mgr = ProcessArgumentManager(arguments)
return exp((device_mgr, dataset_mgr, argument_mgr))
def run(with_file=False): def run(with_file=False):

View File

@ -7,8 +7,7 @@ from artiq.protocols import pyon
__all__ = ["NoDefault", __all__ = ["NoDefault",
"PYONValue", "BooleanValue", "EnumerationValue", "PYONValue", "BooleanValue", "EnumerationValue",
"NumberValue", "StringValue", "NumberValue", "StringValue",
"HasEnvironment", "HasEnvironment", "Experiment", "EnvExperiment"]
"Experiment", "EnvExperiment", "is_experiment"]
class NoDefault: class NoDefault:
@ -145,51 +144,63 @@ class StringValue(_SimpleArgProcessor):
pass pass
class HasEnvironment: class TraceArgumentManager:
"""Provides methods to manage the environment of an experiment (devices, def __init__(self):
parameters, results, arguments)."""
def __init__(self, device_mgr=None, dataset_mgr=None, *, parent=None,
default_arg_none=False, **kwargs):
self.requested_args = OrderedDict() self.requested_args = OrderedDict()
self.__device_mgr = device_mgr def get(self, key, processor, group):
self.__dataset_mgr = dataset_mgr self.requested_args[key] = processor, group
self.__parent = parent return None
self.__default_arg_none = default_arg_none
class ProcessArgumentManager:
def __init__(self, unprocessed_arguments):
self.unprocessed_arguments = unprocessed_arguments
def get(self, key, processor, group):
if key in self.unprocessed_arguments:
r = processor.process(self.unprocessed_arguments[key])
else:
r = processor.default()
return r
class HasEnvironment:
"""Provides methods to manage the environment of an experiment (arguments,
devices, datasets)."""
def __init__(self, managers_or_parent, *args, **kwargs):
if isinstance(managers_or_parent, tuple):
self.__device_mgr = managers_or_parent[0]
self.__dataset_mgr = managers_or_parent[1]
self.__argument_mgr = managers_or_parent[2]
else:
self.__device_mgr = managers_or_parent.__device_mgr
self.__dataset_mgr = managers_or_parent.__dataset_mgr
self.__argument_mgr = managers_or_parent.__argument_mgr
self.__kwargs = kwargs
self.__in_build = True self.__in_build = True
self.build() self.build(*args, **kwargs)
self.__in_build = False self.__in_build = False
for key in self.__kwargs.keys():
if key not in self.requested_args:
raise TypeError("Got unexpected argument: " + key)
del self.__kwargs
def build(self): def build(self):
"""Must be implemented by the user to request arguments. """Should be implemented by the user to request arguments.
Other initialization steps such as requesting devices and parameters Other initialization steps such as requesting devices may also be
or initializing real-time results may also be performed here. performed here.
When the repository is scanned, any requested devices and parameters When the repository is scanned, any requested devices and arguments
are set to ``None``.""" are set to ``None``.
raise NotImplementedError
def managers(self): Leftover positional and keyword arguments from the constructor are
"""Returns the device manager and the dataset manager, in this order. forwarded to this method. This is intended for experiments that are
only meant to be executed programmatically (not from the GUI)."""
pass
This is the same order that the constructor takes them, allowing def get_argument(self, key, processor, group=None):
sub-objects to be created with this idiom to pass the environment
around: ::
sub_object = SomeLibrary(*self.managers())
"""
return self.__device_mgr, self.__dataset_mgr
def get_argument(self, key, processor=None, group=None):
"""Retrieves and returns the value of an argument. """Retrieves and returns the value of an argument.
This function should only be called from ``build``.
:param key: Name of the argument. :param key: Name of the argument.
:param processor: A description of how to process the argument, such :param processor: A description of how to process the argument, such
as instances of ``BooleanValue`` and ``NumberValue``. as instances of ``BooleanValue`` and ``NumberValue``.
@ -199,22 +210,7 @@ class HasEnvironment:
if not self.__in_build: if not self.__in_build:
raise TypeError("get_argument() should only " raise TypeError("get_argument() should only "
"be called from build()") "be called from build()")
if self.__parent is not None and key not in self.__kwargs: return self.__argument_mgr.get(key, processor, group)
return self.__parent.get_argument(key, processor, group)
if processor is None:
processor = PYONValue()
self.requested_args[key] = processor, group
try:
argval = self.__kwargs[key]
except KeyError:
try:
return processor.default()
except DefaultMissing:
if self.__default_arg_none:
return None
else:
raise
return processor.process(argval)
def setattr_argument(self, key, processor=None, group=None): def setattr_argument(self, key, processor=None, group=None):
"""Sets an argument as attribute. The names of the argument and of the """Sets an argument as attribute. The names of the argument and of the
@ -223,16 +219,10 @@ class HasEnvironment:
def get_device_db(self): def get_device_db(self):
"""Returns the full contents of the device database.""" """Returns the full contents of the device database."""
if self.__parent is not None:
return self.__parent.get_device_db()
return self.__device_mgr.get_device_db() return self.__device_mgr.get_device_db()
def get_device(self, key): def get_device(self, key):
"""Creates and returns a device driver.""" """Creates and returns a device driver."""
if self.__parent is not None:
return self.__parent.get_device(key)
if self.__device_mgr is None:
raise ValueError("Device manager not present")
return self.__device_mgr.get(key) return self.__device_mgr.get(key)
def setattr_device(self, key): def setattr_device(self, key):
@ -254,11 +244,6 @@ class HasEnvironment:
:param save: the data is saved into the local storage of the current :param save: the data is saved into the local storage of the current
run (archived as a HDF5 file). run (archived as a HDF5 file).
""" """
if self.__parent is not None:
self.__parent.set_dataset(key, value, broadcast, persist, save)
return
if self.__dataset_mgr is None:
raise ValueError("Dataset manager not present")
self.__dataset_mgr.set(key, value, broadcast, persist, save) self.__dataset_mgr.set(key, value, broadcast, persist, save)
def mutate_dataset(self, key, index, value): def mutate_dataset(self, key, index, value):
@ -267,10 +252,6 @@ class HasEnvironment:
If the dataset was created in broadcast mode, the modification is If the dataset was created in broadcast mode, the modification is
immediately transmitted.""" immediately transmitted."""
if self.__parent is not None:
self.__parent.mutate_dataset(key, index, value)
if self.__dataset_mgr is None:
raise ValueError("Dataset manager not present")
self.__dataset_mgr.mutate(key, index, value) self.__dataset_mgr.mutate(key, index, value)
def get_dataset(self, key, default=NoDefault): def get_dataset(self, key, default=NoDefault):
@ -283,10 +264,6 @@ class HasEnvironment:
If the dataset does not exist, returns the default value. If no default If the dataset does not exist, returns the default value. If no default
is provided, raises ``KeyError``. is provided, raises ``KeyError``.
""" """
if self.__parent is not None:
return self.__parent.get_dataset(key, default)
if self.__dataset_mgr is None:
raise ValueError("Dataset manager not present")
try: try:
return self.__dataset_mgr.get(key) return self.__dataset_mgr.get(key)
except KeyError: except KeyError:
@ -302,7 +279,7 @@ class HasEnvironment:
class Experiment: class Experiment:
"""Base class for experiments. """Base class for top-level experiments.
Deriving from this class enables automatic experiment discovery in Deriving from this class enables automatic experiment discovery in
Python modules. Python modules.
@ -348,15 +325,15 @@ class Experiment:
class EnvExperiment(Experiment, HasEnvironment): class EnvExperiment(Experiment, HasEnvironment):
"""Base class for experiments that use the ``HasEnvironment`` environment """Base class for top-level experiments that use the ``HasEnvironment``
manager. environment manager.
Most experiment should derive from this class.""" Most experiment should derive from this class."""
pass pass
def is_experiment(o): def is_experiment(o):
"""Checks if a Python object is an instantiable user experiment.""" """Checks if a Python object is a top-level experiment class."""
return (isclass(o) return (isclass(o)
and issubclass(o, Experiment) and issubclass(o, Experiment)
and o is not Experiment and o is not Experiment

View File

@ -12,7 +12,8 @@ from artiq.protocols import pipe_ipc, pyon
from artiq.protocols.packed_exceptions import raise_packed_exc from artiq.protocols.packed_exceptions import raise_packed_exc
from artiq.tools import multiline_log_config, file_import from artiq.tools import multiline_log_config, file_import
from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.master.worker_db import DeviceManager, DatasetManager
from artiq.language.environment import is_experiment from artiq.language.environment import (is_experiment, TraceArgumentManager,
ProcessArgumentManager)
from artiq.language.core import set_watchdog_factory, TerminationRequested from artiq.language.core import set_watchdog_factory, TerminationRequested
from artiq.coredevice.core import CompileError, host_only, _render_diagnostic from artiq.coredevice.core import CompileError, host_only, _render_diagnostic
from artiq import __version__ as artiq_version from artiq import __version__ as artiq_version
@ -138,11 +139,11 @@ def examine(device_mgr, dataset_mgr, file):
name = exp_class.__doc__.splitlines()[0].strip() name = exp_class.__doc__.splitlines()[0].strip()
if name[-1] == ".": if name[-1] == ".":
name = name[:-1] name = name[:-1]
exp_inst = exp_class(device_mgr, dataset_mgr, argument_mgr = TraceArgumentManager()
default_arg_none=True) exp_inst = exp_class((device_mgr, dataset_mgr, argument_mgr))
arginfo = OrderedDict( arginfo = OrderedDict(
(k, (proc.describe(), group)) (k, (proc.describe(), group))
for k, (proc, group) in exp_inst.requested_args.items()) for k, (proc, group) in argument_mgr.requested_args.items())
register_experiment(class_name, name, arginfo) register_experiment(class_name, name, arginfo)
@ -213,8 +214,8 @@ def main():
time.strftime("%H", start_time)) time.strftime("%H", start_time))
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
os.chdir(dirname) os.chdir(dirname)
exp_inst = exp( argument_mgr = ProcessArgumentManager(expid["arguments"])
device_mgr, dataset_mgr, **expid["arguments"]) exp_inst = exp((device_mgr, dataset_mgr, argument_mgr))
put_object({"action": "completed"}) put_object({"action": "completed"})
elif action == "prepare": elif action == "prepare":
exp_inst.prepare() exp_inst.prepare()

View File

@ -6,19 +6,21 @@ from artiq.sim import devices as sim_devices
from artiq.test.hardware_testbench import ExperimentCase from artiq.test.hardware_testbench import ExperimentCase
def _run_on_host(k_class, **arguments): def _run_on_host(k_class, *args, **kwargs):
dmgr = dict() device_mgr = dict()
dmgr["core"] = sim_devices.Core(dmgr) device_mgr["core"] = sim_devices.Core(device_mgr)
k_inst = k_class(dmgr, **arguments)
k_inst = k_class((device_mgr, None, None),
*args, **kwargs)
k_inst.run() k_inst.run()
return k_inst return k_inst
class _Primes(EnvExperiment): class _Primes(EnvExperiment):
def build(self): def build(self, output_list, maximum):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("output_list") self.output_list = output_list
self.setattr_argument("maximum") self.maximum = maximum
def _add_output(self, x): def _add_output(self, x):
self.output_list.append(x) self.output_list.append(x)
@ -72,10 +74,10 @@ class _Misc(EnvExperiment):
class _PulseLogger(EnvExperiment): class _PulseLogger(EnvExperiment):
def build(self): def build(self, parent_test, name):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("parent_test") self.parent_test = parent_test
self.setattr_argument("name") self.name = name
def _append(self, t, l, f): def _append(self, t, l, f):
if not hasattr(self.parent_test, "first_timestamp"): if not hasattr(self.parent_test, "first_timestamp"):
@ -98,12 +100,12 @@ class _PulseLogger(EnvExperiment):
class _Pulses(EnvExperiment): class _Pulses(EnvExperiment):
def build(self): def build(self, output_list):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("output_list") self.output_list = output_list
for name in "a", "b", "c", "d": for name in "a", "b", "c", "d":
pl = _PulseLogger(*self.managers(), pl = _PulseLogger(self,
parent_test=self, parent_test=self,
name=name) name=name)
setattr(self, name, pl) setattr(self, name, pl)
@ -125,9 +127,9 @@ class _MyException(Exception):
class _Exceptions(EnvExperiment): class _Exceptions(EnvExperiment):
def build(self): def build(self, trace):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("trace") self.trace = trace
def _trace(self, i): def _trace(self, i):
self.trace.append(i) self.trace.append(i)
@ -172,9 +174,9 @@ class _Exceptions(EnvExperiment):
class _RPCExceptions(EnvExperiment): class _RPCExceptions(EnvExperiment):
def build(self): def build(self, catch):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("catch", PYONValue(False)) self.catch = catch
self.success = False self.success = False

View File

@ -147,11 +147,11 @@ class Watchdog(EnvExperiment):
class LoopbackCount(EnvExperiment): class LoopbackCount(EnvExperiment):
def build(self): def build(self, npulses):
self.setattr_device("core") self.setattr_device("core")
self.setattr_device("loop_in") self.setattr_device("loop_in")
self.setattr_device("loop_out") self.setattr_device("loop_out")
self.setattr_argument("npulses") self.npulses = npulses
def set_count(self, count): def set_count(self, count):
self.set_dataset("count", count) self.set_dataset("count", count)
@ -320,9 +320,9 @@ class CoredeviceTest(ExperimentCase):
class RPCTiming(EnvExperiment): class RPCTiming(EnvExperiment):
def build(self): def build(self, repeats=100):
self.setattr_device("core") self.setattr_device("core")
self.setattr_argument("repeats", PYONValue(100)) self.repeats = repeats
def nop(self): def nop(self):
pass pass

View File

@ -108,9 +108,11 @@ class ExperimentCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.device_mgr.close_devices() self.device_mgr.close_devices()
def create(self, cls, **kwargs): def create(self, cls, *args, **kwargs):
try: try:
exp = cls(self.device_mgr, self.dataset_mgr, **kwargs) exp = cls(
(self.device_mgr, self.dataset_mgr, None),
*args, **kwargs)
exp.prepare() exp.prepare()
return exp return exp
except KeyError as e: except KeyError as e:
@ -118,15 +120,15 @@ class ExperimentCase(unittest.TestCase):
raise unittest.SkipTest( raise unittest.SkipTest(
"device_db entry `{}` not found".format(*e.args)) "device_db entry `{}` not found".format(*e.args))
def execute(self, cls, **kwargs): def execute(self, cls, *args, **kwargs):
expid = { expid = {
"file": sys.modules[cls.__module__].__file__, "file": sys.modules[cls.__module__].__file__,
"class_name": cls.__name__, "class_name": cls.__name__,
"arguments": kwargs "arguments": dict()
} }
self.device_mgr.virtual_devices["scheduler"].expid = expid self.device_mgr.virtual_devices["scheduler"].expid = expid
try: try:
exp = self.create(cls, **kwargs) exp = self.create(cls, *args, **kwargs)
exp.run() exp.run()
exp.analyze() exp.analyze()
return exp return exp