diff --git a/artiq/language/environment.py b/artiq/language/environment.py index 766779177..eff9fc3f1 100644 --- a/artiq/language/environment.py +++ b/artiq/language/environment.py @@ -32,7 +32,7 @@ class _SimpleArgProcessor: if isinstance(default, list): raise NotImplementedError if default is not NoDefault: - self.default_value = default + self.default_value = self.process(default) def default(self): if not hasattr(self, "default_value"): @@ -54,6 +54,7 @@ class PYONValue(_SimpleArgProcessor): def __init__(self, default=NoDefault): # Override the _SimpleArgProcessor init, as list defaults are valid # PYON values + # default stays decoded if default is not NoDefault: self.default_value = default @@ -69,7 +70,13 @@ class PYONValue(_SimpleArgProcessor): class BooleanValue(_SimpleArgProcessor): """A boolean argument.""" - pass + def process(self, x): + if x is True: + return True + elif x is False: + return False + else: + raise ValueError("Invalid BooleanValue value") class EnumerationValue(_SimpleArgProcessor): @@ -80,15 +87,20 @@ class EnumerationValue(_SimpleArgProcessor): argument. """ def __init__(self, choices, default=NoDefault): - _SimpleArgProcessor.__init__(self, default) - assert default is NoDefault or default in choices self.choices = choices + super().__init__(default) + + def process(self, x): + if x not in self.choices: + raise ValueError("Invalid EnumerationValue value") + return x def describe(self): d = _SimpleArgProcessor.describe(self) d["choices"] = self.choices return d + class NumberValue(_SimpleArgProcessor): """An argument that can take a numerical value. @@ -132,8 +144,6 @@ class NumberValue(_SimpleArgProcessor): "the scale manually".format(unit)) if step is None: step = scale/10.0 - if default is not NoDefault: - self.default_value = default self.unit = unit self.scale = scale self.step = step @@ -141,19 +151,13 @@ class NumberValue(_SimpleArgProcessor): self.max = max self.ndecimals = ndecimals + super().__init__(default) + def _is_int(self): return (self.ndecimals == 0 and int(self.step) == self.step and self.scale == 1) - def default(self): - if not hasattr(self, "default_value"): - raise DefaultMissing - if self._is_int(): - return int(self.default_value) - else: - return float(self.default_value) - def process(self, x): if self._is_int(): return int(x) diff --git a/artiq/test/test_arguments.py b/artiq/test/test_arguments.py new file mode 100644 index 000000000..884c8982f --- /dev/null +++ b/artiq/test/test_arguments.py @@ -0,0 +1,79 @@ +import unittest +import numbers + + +from artiq.language.environment import BooleanValue, EnumerationValue, \ + NumberValue, DefaultMissing + + +class NumberValueCase(unittest.TestCase): + def setUp(self): + self.default_value = NumberValue() + self.int_value = NumberValue(42, step=1, ndecimals=0) + self.float_value = NumberValue(42) + + def test_invalid_default(self): + with self.assertRaises(ValueError): + _ = NumberValue("invalid") + + with self.assertRaises(TypeError): + _ = NumberValue(1.+1j) + + def test_no_default(self): + with self.assertRaises(DefaultMissing): + self.default_value.default() + + def test_integer_default(self): + self.assertIsInstance(self.int_value.default(), numbers.Integral) + + def test_default_to_float(self): + self.assertIsInstance(self.float_value.default(), numbers.Real) + self.assertNotIsInstance(self.float_value.default(), numbers.Integral) + + def test_invalid_unit(self): + with self.assertRaises(KeyError): + _ = NumberValue(unit="invalid") + + def test_default_scale(self): + self.assertEqual(self.default_value.scale, 1.) + + +class BooleanValueCase(unittest.TestCase): + def setUp(self): + self.default_value = BooleanValue() + self.true_value = BooleanValue(True) + self.false_value = BooleanValue(False) + + def test_default(self): + self.assertIs(self.true_value.default(), True) + self.assertIs(self.false_value.default(), False) + + def test_no_default(self): + with self.assertRaises(DefaultMissing): + self.default_value.default() + + def test_invalid_default(self): + with self.assertRaises(ValueError): + _ = BooleanValue(1) + + with self.assertRaises(ValueError): + _ = BooleanValue("abc") + + +class EnumerationValueCase(unittest.TestCase): + def setUp(self): + self.default_value = EnumerationValue(["abc"]) + + def test_no_default(self): + with self.assertRaises(DefaultMissing): + self.default_value.default() + + def test_invalid_default(self): + with self.assertRaises(ValueError): + _ = EnumerationValue("abc", "d") + + def test_valid_default(self): + try: + _ = EnumerationValue("abc", "a") + except ValueError: + self.fail("Unexpected ValueError")