forked from M-Labs/artiq
language/environment: cast argument processor default values early
Fixes #1434. Also add unit tests for some argument processors. Signed-off-by: Etienne Wodey <wodey@iqo.uni-hannover.de>
This commit is contained in:
parent
371d923385
commit
9b03a365ed
|
@ -32,7 +32,7 @@ class _SimpleArgProcessor:
|
|||
if isinstance(default, list):
|
||||
raise NotImplementedError
|
||||
if default is not NoDefault:
|
||||
self.default_value = default
|
||||
self.default_value = self.process(default)
|
||||
|
||||
def default(self):
|
||||
if not hasattr(self, "default_value"):
|
||||
|
@ -54,6 +54,7 @@ class PYONValue(_SimpleArgProcessor):
|
|||
def __init__(self, default=NoDefault):
|
||||
# Override the _SimpleArgProcessor init, as list defaults are valid
|
||||
# PYON values
|
||||
# default stays decoded
|
||||
if default is not NoDefault:
|
||||
self.default_value = default
|
||||
|
||||
|
@ -69,7 +70,13 @@ class PYONValue(_SimpleArgProcessor):
|
|||
|
||||
class BooleanValue(_SimpleArgProcessor):
|
||||
"""A boolean argument."""
|
||||
pass
|
||||
def process(self, x):
|
||||
if x is True:
|
||||
return True
|
||||
elif x is False:
|
||||
return False
|
||||
else:
|
||||
raise ValueError("Invalid BooleanValue value")
|
||||
|
||||
|
||||
class EnumerationValue(_SimpleArgProcessor):
|
||||
|
@ -80,15 +87,20 @@ class EnumerationValue(_SimpleArgProcessor):
|
|||
argument.
|
||||
"""
|
||||
def __init__(self, choices, default=NoDefault):
|
||||
_SimpleArgProcessor.__init__(self, default)
|
||||
assert default is NoDefault or default in choices
|
||||
self.choices = choices
|
||||
super().__init__(default)
|
||||
|
||||
def process(self, x):
|
||||
if x not in self.choices:
|
||||
raise ValueError("Invalid EnumerationValue value")
|
||||
return x
|
||||
|
||||
def describe(self):
|
||||
d = _SimpleArgProcessor.describe(self)
|
||||
d["choices"] = self.choices
|
||||
return d
|
||||
|
||||
|
||||
class NumberValue(_SimpleArgProcessor):
|
||||
"""An argument that can take a numerical value.
|
||||
|
||||
|
@ -132,8 +144,6 @@ class NumberValue(_SimpleArgProcessor):
|
|||
"the scale manually".format(unit))
|
||||
if step is None:
|
||||
step = scale/10.0
|
||||
if default is not NoDefault:
|
||||
self.default_value = default
|
||||
self.unit = unit
|
||||
self.scale = scale
|
||||
self.step = step
|
||||
|
@ -141,19 +151,13 @@ class NumberValue(_SimpleArgProcessor):
|
|||
self.max = max
|
||||
self.ndecimals = ndecimals
|
||||
|
||||
super().__init__(default)
|
||||
|
||||
def _is_int(self):
|
||||
return (self.ndecimals == 0
|
||||
and int(self.step) == self.step
|
||||
and self.scale == 1)
|
||||
|
||||
def default(self):
|
||||
if not hasattr(self, "default_value"):
|
||||
raise DefaultMissing
|
||||
if self._is_int():
|
||||
return int(self.default_value)
|
||||
else:
|
||||
return float(self.default_value)
|
||||
|
||||
def process(self, x):
|
||||
if self._is_int():
|
||||
return int(x)
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
import unittest
|
||||
import numbers
|
||||
|
||||
|
||||
from artiq.language.environment import BooleanValue, EnumerationValue, \
|
||||
NumberValue, DefaultMissing
|
||||
|
||||
|
||||
class NumberValueCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.default_value = NumberValue()
|
||||
self.int_value = NumberValue(42, step=1, ndecimals=0)
|
||||
self.float_value = NumberValue(42)
|
||||
|
||||
def test_invalid_default(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = NumberValue("invalid")
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_ = NumberValue(1.+1j)
|
||||
|
||||
def test_no_default(self):
|
||||
with self.assertRaises(DefaultMissing):
|
||||
self.default_value.default()
|
||||
|
||||
def test_integer_default(self):
|
||||
self.assertIsInstance(self.int_value.default(), numbers.Integral)
|
||||
|
||||
def test_default_to_float(self):
|
||||
self.assertIsInstance(self.float_value.default(), numbers.Real)
|
||||
self.assertNotIsInstance(self.float_value.default(), numbers.Integral)
|
||||
|
||||
def test_invalid_unit(self):
|
||||
with self.assertRaises(KeyError):
|
||||
_ = NumberValue(unit="invalid")
|
||||
|
||||
def test_default_scale(self):
|
||||
self.assertEqual(self.default_value.scale, 1.)
|
||||
|
||||
|
||||
class BooleanValueCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.default_value = BooleanValue()
|
||||
self.true_value = BooleanValue(True)
|
||||
self.false_value = BooleanValue(False)
|
||||
|
||||
def test_default(self):
|
||||
self.assertIs(self.true_value.default(), True)
|
||||
self.assertIs(self.false_value.default(), False)
|
||||
|
||||
def test_no_default(self):
|
||||
with self.assertRaises(DefaultMissing):
|
||||
self.default_value.default()
|
||||
|
||||
def test_invalid_default(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = BooleanValue(1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = BooleanValue("abc")
|
||||
|
||||
|
||||
class EnumerationValueCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.default_value = EnumerationValue(["abc"])
|
||||
|
||||
def test_no_default(self):
|
||||
with self.assertRaises(DefaultMissing):
|
||||
self.default_value.default()
|
||||
|
||||
def test_invalid_default(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = EnumerationValue("abc", "d")
|
||||
|
||||
def test_valid_default(self):
|
||||
try:
|
||||
_ = EnumerationValue("abc", "a")
|
||||
except ValueError:
|
||||
self.fail("Unexpected ValueError")
|
Loading…
Reference in New Issue