Make arguments attributes, integrate with AutoContext

This makes them accessible to future "data analysis" methods.
This commit is contained in:
Sebastien Bourdeauducq 2015-01-10 15:41:35 +08:00
parent 2ad063c377
commit 06914bbaa3
5 changed files with 59 additions and 29 deletions

View File

@ -26,7 +26,8 @@ class NoDefault:
class Parameter(_AttributeKind): class Parameter(_AttributeKind):
"""Represents a parameter for ``AutoContext`` to process. """Represents a parameter (from the database) for ``AutoContext``
to process.
:param default: Default value of the parameter to be used if not found :param default: Default value of the parameter to be used if not found
in the database. in the database.
@ -39,6 +40,18 @@ class Parameter(_AttributeKind):
self.write_db = write_db self.write_db = write_db
class Argument(_AttributeKind):
"""Represents an argument (specifiable at instance creation) for
``AutoContext`` to process.
:param default: Default value of the argument to be used if not specified
at instance creation.
"""
def __init__(self, default=NoDefault, write_db=False):
self.default = default
class AutoContext: class AutoContext:
"""Base class to automate device and parameter discovery. """Base class to automate device and parameter discovery.
@ -112,11 +125,22 @@ class AutoContext:
p = getattr(self, k) p = getattr(self, k)
if isinstance(p, Parameter) and p.write_db: if isinstance(p, Parameter) and p.write_db:
self.mvs.register_parameter_wb(self, k) self.mvs.register_parameter_wb(self, k)
if (not hasattr(self, k)
or not isinstance(getattr(self, k), _AttributeKind)):
raise ValueError(
"Got unexpected keyword argument: '{}'".format(k))
setattr(self, k, v) setattr(self, k, v)
for k in dir(self): for k in dir(self):
v = getattr(self, k) v = getattr(self, k)
if isinstance(v, _AttributeKind): if isinstance(v, _AttributeKind):
if isinstance(v, Argument):
# never goes through MVS
if v.default is NoDefault:
raise AttributeError(
"No value specified for argument '{}'".format(k))
value = v.default
else:
if self.mvs is None: if self.mvs is None:
if (isinstance(v, Parameter) if (isinstance(v, Parameter)
and v.default is not NoDefault): and v.default is not NoDefault):

View File

@ -21,8 +21,8 @@ def run(dps, file, unit, arguments):
unit = units[0] unit = units[0]
else: else:
unit = getattr(module, unit) unit = getattr(module, unit)
unit_inst = unit(dps) unit_inst = unit(dps, **arguments)
unit_inst.run(**arguments) unit_inst.run()
def get_object(): def get_object():

View File

@ -6,6 +6,9 @@ class PhotonHistogram(AutoContext):
bdd = Device("dds") bdd = Device("dds")
pmt = Device("ttl_in") pmt = Device("ttl_in")
nbins = Argument(100)
repeats = Argument(100)
@kernel @kernel
def cool_detect(self): def cool_detect(self):
with parallel: with parallel:
@ -20,13 +23,13 @@ class PhotonHistogram(AutoContext):
return self.pmt.count() return self.pmt.count()
@kernel @kernel
def run(self, nbins=100, repeats=100): def run(self):
hist = [0 for _ in range (nbins)] hist = [0 for _ in range(self.nbins)]
for i in range(repeats): for i in range(self.repeats):
n = self.cool_detect() n = self.cool_detect()
if n >= nbins: if n >= self.nbins:
n = nbins - 1 n = self.nbins - 1
hist[n] += 1 hist[n] += 1
print(hist) print(hist)

View File

@ -19,6 +19,9 @@ class Transport(AutoContext):
wait_at_stop = Parameter(100*us) wait_at_stop = Parameter(100*us)
speed = Parameter(1.5) speed = Parameter(1.5)
repeats = Argument(100)
nbins = Argument(100)
def prepare(self, stop): def prepare(self, stop):
t = transport_data["t"][:stop]*self.speed t = transport_data["t"][:stop]*self.speed
u = transport_data["u"][:stop] u = transport_data["u"][:stop]
@ -91,27 +94,27 @@ class Transport(AutoContext):
return self.detect() return self.detect()
@kernel @kernel
def repeat(self, repeats, nbins): def repeat(self):
self.histogram = [0 for _ in range(nbins)] self.histogram = [0 for _ in range(self.nbins)]
for i in range(repeats): for i in range(self.repeats):
n = self.one() n = self.one()
if n >= nbins: if n >= self.nbins:
n = nbins - 1 n = self.nbins - 1
self.histogram[n] += 1 self.histogram[n] += 1
def scan(self, repeats, nbins, stops): def scan(self, stops):
for s in stops: for s in stops:
self.histogram = [] self.histogram = []
# non-kernel, calculate waveforms, build frames # non-kernel, calculate waveforms, build frames
# could also be rpc'ed from repeat() # could also be rpc'ed from repeat()
self.prepare(s) self.prepare(s)
# kernel part # kernel part
self.repeat(repeats, nbins) self.repeat()
# live update 2d plot with current self.histogram # live update 2d plot with current self.histogram
# broadcast(s, self.histogram) # broadcast(s, self.histogram)
def run(self, repeats=100, nbins=100): def run(self):
# scan transport endpoint # scan transport endpoint
stops = range(10, len(transport_data["t"]), 10) stops = range(10, len(transport_data["t"]), 10)
self.scan(repeats, nbins, stops) self.scan(stops)

View File

@ -92,8 +92,8 @@ def main():
print("Failed to parse run arguments") print("Failed to parse run arguments")
sys.exit(1) sys.exit(1)
unit_inst = unit(dps) unit_inst = unit(dps, **arguments)
unit_inst.run(**arguments) unit_inst.run()
if dps.parameter_wb: if dps.parameter_wb:
print("Modified parameters:") print("Modified parameters:")