master: refactor experiments enumeration, use tools.get_experiment

Signed-off-by: Etienne Wodey <wodey@iqo.uni-hannover.de>
pull/1617/head
Etienne Wodey 2021-02-11 23:01:54 +01:00 committed by Sébastien Bourdeauducq
parent 2ca9b64ba1
commit 3cd96a951a
4 changed files with 41 additions and 36 deletions

View File

@ -64,6 +64,8 @@ Breaking changes:
* ``quamash`` has been replaced with ``qasync``.
* Protocols are updated to use device endian.
* Analyzer dump format includes a byte for device endianness.
* Experiment classes with underscore-prefixed names are now ignored when ``artiq_client``
determines which experiment to submit (consistent with ``artiq_run``).
ARTIQ-5
-------

View File

@ -488,3 +488,10 @@ def is_experiment(o):
and issubclass(o, Experiment)
and o is not Experiment
and o is not EnvExperiment)
def is_public_experiment(o):
"""Checks if a Pyhton object is a top-level,
non underscore-prefixed, experiment class.
"""
return is_experiment(o) and not o.__name__.startswith("_")

View File

@ -9,6 +9,7 @@ process via IPC.
import sys
import time
import os
import inspect
import logging
import traceback
from collections import OrderedDict
@ -20,10 +21,11 @@ from sipyco.packed_exceptions import raise_packed_exc
from sipyco.logging_tools import multiline_log_config
import artiq
from artiq.tools import file_import
from artiq import tools
from artiq.master.worker_db import DeviceManager, DatasetManager, DummyDevice
from artiq.language.environment import (is_experiment, TraceArgumentManager,
ProcessArgumentManager)
from artiq.language.environment import (
is_public_experiment, TraceArgumentManager, ProcessArgumentManager
)
from artiq.language.core import set_watchdog_factory, TerminationRequested
from artiq.language.types import TBool
from artiq.compiler import import_cache
@ -127,17 +129,9 @@ class CCB:
issue = staticmethod(make_parent_action("ccb_issue"))
def get_exp(file, class_name):
module = file_import(file, prefix="artiq_worker_")
if class_name is None:
exps = [v for k, v in module.__dict__.items()
if is_experiment(v)]
if len(exps) != 1:
raise ValueError("Found {} experiments in module"
.format(len(exps)))
return exps[0]
else:
return getattr(module, class_name)
def get_experiment(file, class_name):
module = tools.file_import(file, prefix="artiq_worker_")
return tools.get_experiment(module, class_name)
register_experiment = make_parent_action("register_experiment")
@ -164,24 +158,24 @@ class ExamineDatasetMgr:
def examine(device_mgr, dataset_mgr, file):
previous_keys = set(sys.modules.keys())
try:
module = file_import(file)
for class_name, exp_class in module.__dict__.items():
if class_name[0] == "_":
continue
if is_experiment(exp_class):
if exp_class.__doc__ is None:
name = class_name
else:
name = exp_class.__doc__.strip().splitlines()[0].strip()
if name[-1] == ".":
name = name[:-1]
argument_mgr = TraceArgumentManager()
scheduler_defaults = {}
cls = exp_class((device_mgr, dataset_mgr, argument_mgr, scheduler_defaults))
arginfo = OrderedDict(
(k, (proc.describe(), group, tooltip))
for k, (proc, group, tooltip) in argument_mgr.requested_args.items())
register_experiment(class_name, name, arginfo, scheduler_defaults)
module = tools.file_import(file)
for class_name, exp_class in inspect.getmembers(module, is_public_experiment):
if exp_class.__doc__ is None:
name = class_name
else:
name = exp_class.__doc__.strip().splitlines()[0].strip()
if name[-1] == ".":
name = name[:-1]
argument_mgr = TraceArgumentManager()
scheduler_defaults = {}
cls = exp_class( # noqa: F841 (fill argument_mgr)
(device_mgr, dataset_mgr, argument_mgr, scheduler_defaults)
)
arginfo = OrderedDict(
(k, (proc.describe(), group, tooltip))
for k, (proc, group, tooltip) in argument_mgr.requested_args.items()
)
register_experiment(class_name, name, arginfo, scheduler_defaults)
finally:
new_keys = set(sys.modules.keys())
for key in new_keys - previous_keys:
@ -285,7 +279,7 @@ def main():
experiment_file = expid["file"]
repository_path = None
setup_diagnostics(experiment_file, repository_path)
exp = get_exp(experiment_file, expid["class_name"])
exp = get_experiment(experiment_file, expid["class_name"])
device_mgr.virtual_devices["scheduler"].set_run_info(
rid, obj["pipeline_name"], expid, obj["priority"])
start_local_time = time.localtime(start_time)

View File

@ -1,5 +1,6 @@
import asyncio
import importlib.util
import inspect
import logging
import os
import pathlib
@ -12,7 +13,7 @@ from sipyco import pyon
from artiq import __version__ as artiq_version
from artiq.appdirs import user_config_dir
from artiq.language.environment import is_experiment
from artiq.language.environment import is_public_experiment
__all__ = ["parse_arguments", "elide", "short_format", "file_import",
@ -90,12 +91,13 @@ def get_experiment(module, class_name=None):
if class_name:
return getattr(module, class_name)
exps = [(k, v) for k, v in module.__dict__.items()
if k[0] != "_" and is_experiment(v)]
exps = inspect.getmembers(module, is_public_experiment)
if not exps:
raise ValueError("No experiments in module")
if len(exps) > 1:
raise ValueError("More than one experiment found in module")
return exps[0][1]