language/environment: cast argument processor default values early

Fixes #1434. Also add unit tests for some argument processors.

Signed-off-by: Etienne Wodey <wodey@iqo.uni-hannover.de>
pull/1445/head
Etienne Wodey 2020-04-03 16:15:47 +02:00 committed by Sébastien Bourdeauducq
parent 371d923385
commit 9b03a365ed
2 changed files with 97 additions and 14 deletions

View File

@ -32,7 +32,7 @@ class _SimpleArgProcessor:
if isinstance(default, list):
raise NotImplementedError
if default is not NoDefault:
self.default_value = default
self.default_value = self.process(default)
def default(self):
if not hasattr(self, "default_value"):
@ -54,6 +54,7 @@ class PYONValue(_SimpleArgProcessor):
def __init__(self, default=NoDefault):
# Override the _SimpleArgProcessor init, as list defaults are valid
# PYON values
# default stays decoded
if default is not NoDefault:
self.default_value = default
@ -69,7 +70,13 @@ class PYONValue(_SimpleArgProcessor):
class BooleanValue(_SimpleArgProcessor):
"""A boolean argument."""
pass
def process(self, x):
if x is True:
return True
elif x is False:
return False
else:
raise ValueError("Invalid BooleanValue value")
class EnumerationValue(_SimpleArgProcessor):
@ -80,15 +87,20 @@ class EnumerationValue(_SimpleArgProcessor):
argument.
"""
def __init__(self, choices, default=NoDefault):
_SimpleArgProcessor.__init__(self, default)
assert default is NoDefault or default in choices
self.choices = choices
super().__init__(default)
def process(self, x):
if x not in self.choices:
raise ValueError("Invalid EnumerationValue value")
return x
def describe(self):
d = _SimpleArgProcessor.describe(self)
d["choices"] = self.choices
return d
class NumberValue(_SimpleArgProcessor):
"""An argument that can take a numerical value.
@ -132,8 +144,6 @@ class NumberValue(_SimpleArgProcessor):
"the scale manually".format(unit))
if step is None:
step = scale/10.0
if default is not NoDefault:
self.default_value = default
self.unit = unit
self.scale = scale
self.step = step
@ -141,19 +151,13 @@ class NumberValue(_SimpleArgProcessor):
self.max = max
self.ndecimals = ndecimals
super().__init__(default)
def _is_int(self):
return (self.ndecimals == 0
and int(self.step) == self.step
and self.scale == 1)
def default(self):
if not hasattr(self, "default_value"):
raise DefaultMissing
if self._is_int():
return int(self.default_value)
else:
return float(self.default_value)
def process(self, x):
if self._is_int():
return int(x)

View File

@ -0,0 +1,79 @@
import unittest
import numbers
from artiq.language.environment import BooleanValue, EnumerationValue, \
NumberValue, DefaultMissing
class NumberValueCase(unittest.TestCase):
def setUp(self):
self.default_value = NumberValue()
self.int_value = NumberValue(42, step=1, ndecimals=0)
self.float_value = NumberValue(42)
def test_invalid_default(self):
with self.assertRaises(ValueError):
_ = NumberValue("invalid")
with self.assertRaises(TypeError):
_ = NumberValue(1.+1j)
def test_no_default(self):
with self.assertRaises(DefaultMissing):
self.default_value.default()
def test_integer_default(self):
self.assertIsInstance(self.int_value.default(), numbers.Integral)
def test_default_to_float(self):
self.assertIsInstance(self.float_value.default(), numbers.Real)
self.assertNotIsInstance(self.float_value.default(), numbers.Integral)
def test_invalid_unit(self):
with self.assertRaises(KeyError):
_ = NumberValue(unit="invalid")
def test_default_scale(self):
self.assertEqual(self.default_value.scale, 1.)
class BooleanValueCase(unittest.TestCase):
def setUp(self):
self.default_value = BooleanValue()
self.true_value = BooleanValue(True)
self.false_value = BooleanValue(False)
def test_default(self):
self.assertIs(self.true_value.default(), True)
self.assertIs(self.false_value.default(), False)
def test_no_default(self):
with self.assertRaises(DefaultMissing):
self.default_value.default()
def test_invalid_default(self):
with self.assertRaises(ValueError):
_ = BooleanValue(1)
with self.assertRaises(ValueError):
_ = BooleanValue("abc")
class EnumerationValueCase(unittest.TestCase):
def setUp(self):
self.default_value = EnumerationValue(["abc"])
def test_no_default(self):
with self.assertRaises(DefaultMissing):
self.default_value.default()
def test_invalid_default(self):
with self.assertRaises(ValueError):
_ = EnumerationValue("abc", "d")
def test_valid_default(self):
try:
_ = EnumerationValue("abc", "a")
except ValueError:
self.fail("Unexpected ValueError")