forked from M-Labs/artiq
1
0
Fork 0

master: refactor experiments enumeration, use tools.get_experiment

Signed-off-by: Etienne Wodey <wodey@iqo.uni-hannover.de>
This commit is contained in:
Etienne Wodey 2021-02-11 23:01:54 +01:00 committed by Sébastien Bourdeauducq
parent 2ca9b64ba1
commit 3cd96a951a
4 changed files with 41 additions and 36 deletions

View File

@ -64,6 +64,8 @@ Breaking changes:
* ``quamash`` has been replaced with ``qasync``. * ``quamash`` has been replaced with ``qasync``.
* Protocols are updated to use device endian. * Protocols are updated to use device endian.
* Analyzer dump format includes a byte for device endianness. * Analyzer dump format includes a byte for device endianness.
* Experiment classes with underscore-prefixed names are now ignored when ``artiq_client``
determines which experiment to submit (consistent with ``artiq_run``).
ARTIQ-5 ARTIQ-5
------- -------

View File

@ -488,3 +488,10 @@ def is_experiment(o):
and issubclass(o, Experiment) and issubclass(o, Experiment)
and o is not Experiment and o is not Experiment
and o is not EnvExperiment) and o is not EnvExperiment)
def is_public_experiment(o):
"""Checks if a Pyhton object is a top-level,
non underscore-prefixed, experiment class.
"""
return is_experiment(o) and not o.__name__.startswith("_")

View File

@ -9,6 +9,7 @@ process via IPC.
import sys import sys
import time import time
import os import os
import inspect
import logging import logging
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
@ -20,10 +21,11 @@ from sipyco.packed_exceptions import raise_packed_exc
from sipyco.logging_tools import multiline_log_config from sipyco.logging_tools import multiline_log_config
import artiq import artiq
from artiq.tools import file_import from artiq import tools
from artiq.master.worker_db import DeviceManager, DatasetManager, DummyDevice from artiq.master.worker_db import DeviceManager, DatasetManager, DummyDevice
from artiq.language.environment import (is_experiment, TraceArgumentManager, from artiq.language.environment import (
ProcessArgumentManager) is_public_experiment, TraceArgumentManager, ProcessArgumentManager
)
from artiq.language.core import set_watchdog_factory, TerminationRequested from artiq.language.core import set_watchdog_factory, TerminationRequested
from artiq.language.types import TBool from artiq.language.types import TBool
from artiq.compiler import import_cache from artiq.compiler import import_cache
@ -127,17 +129,9 @@ class CCB:
issue = staticmethod(make_parent_action("ccb_issue")) issue = staticmethod(make_parent_action("ccb_issue"))
def get_exp(file, class_name): def get_experiment(file, class_name):
module = file_import(file, prefix="artiq_worker_") module = tools.file_import(file, prefix="artiq_worker_")
if class_name is None: return tools.get_experiment(module, class_name)
exps = [v for k, v in module.__dict__.items()
if is_experiment(v)]
if len(exps) != 1:
raise ValueError("Found {} experiments in module"
.format(len(exps)))
return exps[0]
else:
return getattr(module, class_name)
register_experiment = make_parent_action("register_experiment") register_experiment = make_parent_action("register_experiment")
@ -164,24 +158,24 @@ class ExamineDatasetMgr:
def examine(device_mgr, dataset_mgr, file): def examine(device_mgr, dataset_mgr, file):
previous_keys = set(sys.modules.keys()) previous_keys = set(sys.modules.keys())
try: try:
module = file_import(file) module = tools.file_import(file)
for class_name, exp_class in module.__dict__.items(): for class_name, exp_class in inspect.getmembers(module, is_public_experiment):
if class_name[0] == "_": if exp_class.__doc__ is None:
continue name = class_name
if is_experiment(exp_class): else:
if exp_class.__doc__ is None: name = exp_class.__doc__.strip().splitlines()[0].strip()
name = class_name if name[-1] == ".":
else: name = name[:-1]
name = exp_class.__doc__.strip().splitlines()[0].strip() argument_mgr = TraceArgumentManager()
if name[-1] == ".": scheduler_defaults = {}
name = name[:-1] cls = exp_class( # noqa: F841 (fill argument_mgr)
argument_mgr = TraceArgumentManager() (device_mgr, dataset_mgr, argument_mgr, scheduler_defaults)
scheduler_defaults = {} )
cls = exp_class((device_mgr, dataset_mgr, argument_mgr, scheduler_defaults)) arginfo = OrderedDict(
arginfo = OrderedDict( (k, (proc.describe(), group, tooltip))
(k, (proc.describe(), group, tooltip)) for k, (proc, group, tooltip) in argument_mgr.requested_args.items()
for k, (proc, group, tooltip) in argument_mgr.requested_args.items()) )
register_experiment(class_name, name, arginfo, scheduler_defaults) register_experiment(class_name, name, arginfo, scheduler_defaults)
finally: finally:
new_keys = set(sys.modules.keys()) new_keys = set(sys.modules.keys())
for key in new_keys - previous_keys: for key in new_keys - previous_keys:
@ -285,7 +279,7 @@ def main():
experiment_file = expid["file"] experiment_file = expid["file"]
repository_path = None repository_path = None
setup_diagnostics(experiment_file, repository_path) setup_diagnostics(experiment_file, repository_path)
exp = get_exp(experiment_file, expid["class_name"]) exp = get_experiment(experiment_file, expid["class_name"])
device_mgr.virtual_devices["scheduler"].set_run_info( device_mgr.virtual_devices["scheduler"].set_run_info(
rid, obj["pipeline_name"], expid, obj["priority"]) rid, obj["pipeline_name"], expid, obj["priority"])
start_local_time = time.localtime(start_time) start_local_time = time.localtime(start_time)

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import importlib.util import importlib.util
import inspect
import logging import logging
import os import os
import pathlib import pathlib
@ -12,7 +13,7 @@ from sipyco import pyon
from artiq import __version__ as artiq_version from artiq import __version__ as artiq_version
from artiq.appdirs import user_config_dir from artiq.appdirs import user_config_dir
from artiq.language.environment import is_experiment from artiq.language.environment import is_public_experiment
__all__ = ["parse_arguments", "elide", "short_format", "file_import", __all__ = ["parse_arguments", "elide", "short_format", "file_import",
@ -90,12 +91,13 @@ def get_experiment(module, class_name=None):
if class_name: if class_name:
return getattr(module, class_name) return getattr(module, class_name)
exps = [(k, v) for k, v in module.__dict__.items() exps = inspect.getmembers(module, is_public_experiment)
if k[0] != "_" and is_experiment(v)]
if not exps: if not exps:
raise ValueError("No experiments in module") raise ValueError("No experiments in module")
if len(exps) > 1: if len(exps) > 1:
raise ValueError("More than one experiment found in module") raise ValueError("More than one experiment found in module")
return exps[0][1] return exps[0][1]