forked from M-Labs/artiq
units: error checking
This commit is contained in:
parent
d59d110f78
commit
a3f981726a
|
@ -1,4 +1,7 @@
|
|||
"""Core ARTIQ extensions to the Python language."""
|
||||
"""
|
||||
Core ARTIQ extensions to the Python language.
|
||||
|
||||
"""
|
||||
|
||||
from collections import namedtuple as _namedtuple
|
||||
from copy import copy as _copy
|
||||
|
|
|
@ -1,17 +1,36 @@
|
|||
"""
|
||||
Definition and management of physical units.
|
||||
|
||||
"""
|
||||
|
||||
from fractions import Fraction as _Fraction
|
||||
|
||||
|
||||
_prefixes_str = "pnum_kMG"
|
||||
_smallest_prefix = _Fraction(1, 10**12)
|
||||
class DimensionError(Exception):
|
||||
"""Raised when attempting an operation with incompatible units.
|
||||
|
||||
When targeting the core device, all units are statically managed at
|
||||
compilation time. Thus, when raised by functions in this module, this
|
||||
exception cannot be caught in the kernel as it is raised by the compiler
|
||||
instead.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def mul_dimension(l, r):
|
||||
"""Returns the unit obtained by multiplying unit ``l`` with unit ``r``.
|
||||
|
||||
Raises ``DimensionError`` if the resulting unit is not implemented.
|
||||
|
||||
"""
|
||||
if l is None:
|
||||
return r
|
||||
if r is None:
|
||||
return l
|
||||
if {l, r} == {"Hz", "s"}:
|
||||
return None
|
||||
raise DimensionError
|
||||
|
||||
|
||||
def _rmul_dimension(l, r):
|
||||
|
@ -19,6 +38,11 @@ def _rmul_dimension(l, r):
|
|||
|
||||
|
||||
def div_dimension(l, r):
|
||||
"""Returns the unit obtained by dividing unit ``l`` with unit ``r``.
|
||||
|
||||
Raises ``DimensionError`` if the resulting unit is not implemented.
|
||||
|
||||
"""
|
||||
if l == r:
|
||||
return None
|
||||
if r is None:
|
||||
|
@ -28,6 +52,7 @@ def div_dimension(l, r):
|
|||
return "Hz"
|
||||
if r == "Hz":
|
||||
return "s"
|
||||
raise DimensionError
|
||||
|
||||
|
||||
def _rdiv_dimension(l, r):
|
||||
|
@ -35,10 +60,20 @@ def _rdiv_dimension(l, r):
|
|||
|
||||
|
||||
def addsub_dimension(x, y):
|
||||
"""Returns the unit obtained by adding or subtracting unit ``l`` with
|
||||
unit ``r``.
|
||||
|
||||
Raises ``DimensionError`` if ``l`` and ``r`` are different.
|
||||
|
||||
"""
|
||||
if x == y:
|
||||
return x
|
||||
else:
|
||||
return None
|
||||
raise DimensionError
|
||||
|
||||
|
||||
_prefixes_str = "pnum_kMG"
|
||||
_smallest_prefix = _Fraction(1, 10**12)
|
||||
|
||||
|
||||
def _format(amount, unit):
|
||||
|
@ -139,9 +174,9 @@ class Quantity:
|
|||
|
||||
# comparisons
|
||||
def _cmp(self, other, opf_name):
|
||||
if isinstance(other, Quantity):
|
||||
other = other.amount
|
||||
return getattr(self.amount, opf_name)(other)
|
||||
if not isinstance(other, Quantity) or other.unit != self.unit:
|
||||
raise DimensionError
|
||||
return getattr(self.amount, opf_name)(other.amount)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self._cmp(other, "__lt__")
|
||||
|
@ -173,3 +208,26 @@ def _register_unit(unit, prefixes):
|
|||
|
||||
_register_unit("s", "pnum_")
|
||||
_register_unit("Hz", "_kMG")
|
||||
|
||||
|
||||
def check_unit(value, unit):
|
||||
"""Checks that the value has the specified unit. Unit specification is
|
||||
a string representing the unit without any prefix (e.g. ``s``, ``Hz``).
|
||||
Checking for a dimensionless value (not a ``Quantity`` instance) is done
|
||||
by setting ``unit`` to ``None``.
|
||||
|
||||
If the units do not match, ``DimensionError`` is raised.
|
||||
|
||||
This function can be used in kernels and is executed at compilation time.
|
||||
|
||||
There is already unit checking built into the arithmetic, so you typically
|
||||
need to use this function only when using the ``amount`` property of
|
||||
``Quantity``.
|
||||
|
||||
"""
|
||||
if unit is None:
|
||||
if isinstance(value, Quantity):
|
||||
raise DimensionError
|
||||
else:
|
||||
if not isinstance(value, Quantity) or value.unit != unit:
|
||||
raise DimensionError
|
||||
|
|
|
@ -3,6 +3,7 @@ import textwrap
|
|||
import ast
|
||||
import types
|
||||
import builtins
|
||||
from copy import copy
|
||||
from fractions import Fraction
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
@ -10,7 +11,7 @@ from itertools import zip_longest, chain
|
|||
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
|
||||
from artiq.transforms.tools import *
|
||||
|
||||
|
||||
def new_mangled_name(in_use_names, name):
|
||||
|
@ -35,23 +36,6 @@ class AttributeInfo:
|
|||
self.read_write = read_write
|
||||
|
||||
|
||||
embeddable_funcs = (
|
||||
core_language.delay, core_language.at, core_language.now,
|
||||
core_language.time_to_cycles, core_language.cycles_to_time,
|
||||
core_language.syscall,
|
||||
range, bool, int, float, round,
|
||||
core_language.int64, core_language.round64, core_language.array,
|
||||
Fraction, units.Quantity, core_language.EncodedException
|
||||
)
|
||||
|
||||
|
||||
def is_embeddable(func):
|
||||
for ef in embeddable_funcs:
|
||||
if func is ef:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_inlinable(core, func):
|
||||
if hasattr(func, "k_function_info"):
|
||||
if func.k_function_info.core_name == "":
|
||||
|
@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
|||
def inline(core, k_function, k_args, k_kwargs):
|
||||
# OrderedDict prevents non-determinism in attribute init
|
||||
attribute_namespace = OrderedDict()
|
||||
in_use_names = {func.__name__ for func in embeddable_funcs}
|
||||
in_use_names = copy(embeddable_func_names)
|
||||
mappers = types.SimpleNamespace(
|
||||
rpc=HostObjectMapper(),
|
||||
exception=HostObjectMapper(core_language.first_user_eid)
|
||||
|
|
|
@ -3,6 +3,7 @@ from collections import defaultdict
|
|||
from copy import copy
|
||||
|
||||
from artiq.language import units
|
||||
from artiq.transforms.tools import embeddable_func_names
|
||||
|
||||
|
||||
def _add_units(f, unit_list):
|
||||
|
@ -16,6 +17,7 @@ def _add_units(f, unit_list):
|
|||
class _UnitsLowerer(ast.NodeTransformer):
|
||||
def __init__(self, rpc_map):
|
||||
self.rpc_map = rpc_map
|
||||
# (original rpc number, (unit list)) -> new rpc number
|
||||
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
|
||||
self.variable_units = dict()
|
||||
|
||||
|
@ -29,6 +31,22 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
node.unit = unit
|
||||
return node
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
self.generic_visit(node)
|
||||
us = [getattr(value, "unit", None) for value in node.values]
|
||||
if not all(u == us[0] for u in us[1:]):
|
||||
raise units.DimensionError
|
||||
return node
|
||||
|
||||
def visit_Compare(self, node):
|
||||
self.generic_visit(node)
|
||||
u0 = getattr(node.left, "unit", None)
|
||||
us = [getattr(comparator, "unit", None)
|
||||
for comparator in node.comparators]
|
||||
if not all(u == u0 for u in us):
|
||||
raise units.DimensionError
|
||||
return node
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
self.generic_visit(node)
|
||||
if hasattr(node.operand, "unit"):
|
||||
|
@ -47,6 +65,8 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
elif op in (ast.Div, ast.FloorDiv):
|
||||
unit = units.div_dimension(left_unit, right_unit)
|
||||
else:
|
||||
if left_unit is not None or right_unit is not None:
|
||||
raise units.DimensionError
|
||||
unit = None
|
||||
if unit is not None:
|
||||
node.unit = unit
|
||||
|
@ -66,14 +86,47 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
amount, unit = node.args
|
||||
amount.unit = unit.s
|
||||
return amount
|
||||
elif node.func.id == "now":
|
||||
elif node.func.id in ("now", "cycles_to_time"):
|
||||
node.unit = "s"
|
||||
elif node.func.id == "syscall" and node.args[0].s == "rpc":
|
||||
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
|
||||
rpc_n = node.args[1].n
|
||||
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
|
||||
elif node.func.id == "syscall":
|
||||
# only RPCs can have units
|
||||
if node.args[0].s == "rpc":
|
||||
unit_list = tuple(getattr(arg, "unit", None)
|
||||
for arg in node.args[2:])
|
||||
rpc_n = node.args[1].n
|
||||
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
|
||||
else:
|
||||
if any(hasattr(arg, "unit") for arg in node.args):
|
||||
raise units.DimensionError
|
||||
elif node.func.id in ("delay", "at", "time_to_cycles"):
|
||||
if getattr(node.args[0], "unit", None) != "s":
|
||||
raise units.DimensionError
|
||||
elif node.func.id == "check_unit":
|
||||
self.generic_visit(node)
|
||||
elif node.func.id in embeddable_func_names:
|
||||
# must be last (some embeddable funcs may have units)
|
||||
if any(hasattr(arg, "unit") for arg in node.args):
|
||||
raise units.DimensionError
|
||||
return node
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.generic_visit(node)
|
||||
if (isinstance(node.value, ast.Call)
|
||||
and node.value.func.id == "check_unit"):
|
||||
call = node.value
|
||||
if (isinstance(call.args[1], ast.NameConstant)
|
||||
and call.args[1].value is None):
|
||||
if hasattr(call.value.args[0], "unit"):
|
||||
raise units.DimensionError
|
||||
elif isinstance(call.args[1], ast.Str):
|
||||
if getattr(call.args[0], "unit", None) != call.args[1].s:
|
||||
raise units.DimensionError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def _update_target(self, target, unit):
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in self.variable_units:
|
||||
|
|
|
@ -5,6 +5,24 @@ from artiq.language import core as core_language
|
|||
from artiq.language import units
|
||||
|
||||
|
||||
embeddable_funcs = (
|
||||
core_language.delay, core_language.at, core_language.now,
|
||||
core_language.time_to_cycles, core_language.cycles_to_time,
|
||||
core_language.syscall,
|
||||
range, bool, int, float, round,
|
||||
core_language.int64, core_language.round64, core_language.array,
|
||||
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
|
||||
)
|
||||
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
|
||||
|
||||
|
||||
def is_embeddable(func):
|
||||
for ef in embeddable_funcs:
|
||||
if func is ef:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def eval_ast(expr, symdict=dict()):
|
||||
if not isinstance(expr, ast.Expression):
|
||||
expr = ast.copy_location(ast.Expression(expr), expr)
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
from fractions import Fraction
|
||||
|
||||
from artiq import *
|
||||
from artiq.language.units import Quantity
|
||||
from artiq.language.units import *
|
||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||
from artiq.sim import devices as sim_devices
|
||||
|
||||
|
@ -56,6 +56,22 @@ class _Misc(AutoContext):
|
|||
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
||||
self.inhomogeneous_units.append(Quantity(10, "s"))
|
||||
|
||||
@kernel
|
||||
def dimension_error1(self):
|
||||
print(1*Hz + 1*s)
|
||||
|
||||
@kernel
|
||||
def dimension_error2(self):
|
||||
print(1*Hz < 1*s)
|
||||
|
||||
@kernel
|
||||
def dimension_error3(self):
|
||||
check_unit(1*Hz, "s")
|
||||
|
||||
@kernel
|
||||
def dimension_error4(self):
|
||||
delay(10*Hz)
|
||||
|
||||
|
||||
class _PulseLogger(AutoContext):
|
||||
parameters = "output_list name"
|
||||
|
@ -163,6 +179,14 @@ class ExecutionCase(unittest.TestCase):
|
|||
Fraction("1.2"))
|
||||
self.assertEqual(uut.inhomogeneous_units, [
|
||||
Quantity(1000, "Hz"), Quantity(10, "s")])
|
||||
with self.assertRaises(DimensionError):
|
||||
uut.dimension_error1()
|
||||
with self.assertRaises(DimensionError):
|
||||
uut.dimension_error2()
|
||||
with self.assertRaises(DimensionError):
|
||||
uut.dimension_error3()
|
||||
with self.assertRaises(DimensionError):
|
||||
uut.dimension_error4()
|
||||
|
||||
def test_pulses(self):
|
||||
l_device, l_host = [], []
|
||||
|
|
Loading…
Reference in New Issue