mirror of https://github.com/m-labs/artiq.git
units: error checking
This commit is contained in:
parent
d59d110f78
commit
a3f981726a
|
@ -1,4 +1,7 @@
|
||||||
"""Core ARTIQ extensions to the Python language."""
|
"""
|
||||||
|
Core ARTIQ extensions to the Python language.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
from collections import namedtuple as _namedtuple
|
from collections import namedtuple as _namedtuple
|
||||||
from copy import copy as _copy
|
from copy import copy as _copy
|
||||||
|
|
|
@ -1,17 +1,36 @@
|
||||||
|
"""
|
||||||
|
Definition and management of physical units.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
from fractions import Fraction as _Fraction
|
from fractions import Fraction as _Fraction
|
||||||
|
|
||||||
|
|
||||||
_prefixes_str = "pnum_kMG"
|
class DimensionError(Exception):
|
||||||
_smallest_prefix = _Fraction(1, 10**12)
|
"""Raised when attempting an operation with incompatible units.
|
||||||
|
|
||||||
|
When targeting the core device, all units are statically managed at
|
||||||
|
compilation time. Thus, when raised by functions in this module, this
|
||||||
|
exception cannot be caught in the kernel as it is raised by the compiler
|
||||||
|
instead.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def mul_dimension(l, r):
|
def mul_dimension(l, r):
|
||||||
|
"""Returns the unit obtained by multiplying unit ``l`` with unit ``r``.
|
||||||
|
|
||||||
|
Raises ``DimensionError`` if the resulting unit is not implemented.
|
||||||
|
|
||||||
|
"""
|
||||||
if l is None:
|
if l is None:
|
||||||
return r
|
return r
|
||||||
if r is None:
|
if r is None:
|
||||||
return l
|
return l
|
||||||
if {l, r} == {"Hz", "s"}:
|
if {l, r} == {"Hz", "s"}:
|
||||||
return None
|
return None
|
||||||
|
raise DimensionError
|
||||||
|
|
||||||
|
|
||||||
def _rmul_dimension(l, r):
|
def _rmul_dimension(l, r):
|
||||||
|
@ -19,6 +38,11 @@ def _rmul_dimension(l, r):
|
||||||
|
|
||||||
|
|
||||||
def div_dimension(l, r):
|
def div_dimension(l, r):
|
||||||
|
"""Returns the unit obtained by dividing unit ``l`` with unit ``r``.
|
||||||
|
|
||||||
|
Raises ``DimensionError`` if the resulting unit is not implemented.
|
||||||
|
|
||||||
|
"""
|
||||||
if l == r:
|
if l == r:
|
||||||
return None
|
return None
|
||||||
if r is None:
|
if r is None:
|
||||||
|
@ -28,6 +52,7 @@ def div_dimension(l, r):
|
||||||
return "Hz"
|
return "Hz"
|
||||||
if r == "Hz":
|
if r == "Hz":
|
||||||
return "s"
|
return "s"
|
||||||
|
raise DimensionError
|
||||||
|
|
||||||
|
|
||||||
def _rdiv_dimension(l, r):
|
def _rdiv_dimension(l, r):
|
||||||
|
@ -35,10 +60,20 @@ def _rdiv_dimension(l, r):
|
||||||
|
|
||||||
|
|
||||||
def addsub_dimension(x, y):
|
def addsub_dimension(x, y):
|
||||||
|
"""Returns the unit obtained by adding or subtracting unit ``l`` with
|
||||||
|
unit ``r``.
|
||||||
|
|
||||||
|
Raises ``DimensionError`` if ``l`` and ``r`` are different.
|
||||||
|
|
||||||
|
"""
|
||||||
if x == y:
|
if x == y:
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
return None
|
raise DimensionError
|
||||||
|
|
||||||
|
|
||||||
|
_prefixes_str = "pnum_kMG"
|
||||||
|
_smallest_prefix = _Fraction(1, 10**12)
|
||||||
|
|
||||||
|
|
||||||
def _format(amount, unit):
|
def _format(amount, unit):
|
||||||
|
@ -139,9 +174,9 @@ class Quantity:
|
||||||
|
|
||||||
# comparisons
|
# comparisons
|
||||||
def _cmp(self, other, opf_name):
|
def _cmp(self, other, opf_name):
|
||||||
if isinstance(other, Quantity):
|
if not isinstance(other, Quantity) or other.unit != self.unit:
|
||||||
other = other.amount
|
raise DimensionError
|
||||||
return getattr(self.amount, opf_name)(other)
|
return getattr(self.amount, opf_name)(other.amount)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return self._cmp(other, "__lt__")
|
return self._cmp(other, "__lt__")
|
||||||
|
@ -173,3 +208,26 @@ def _register_unit(unit, prefixes):
|
||||||
|
|
||||||
_register_unit("s", "pnum_")
|
_register_unit("s", "pnum_")
|
||||||
_register_unit("Hz", "_kMG")
|
_register_unit("Hz", "_kMG")
|
||||||
|
|
||||||
|
|
||||||
|
def check_unit(value, unit):
|
||||||
|
"""Checks that the value has the specified unit. Unit specification is
|
||||||
|
a string representing the unit without any prefix (e.g. ``s``, ``Hz``).
|
||||||
|
Checking for a dimensionless value (not a ``Quantity`` instance) is done
|
||||||
|
by setting ``unit`` to ``None``.
|
||||||
|
|
||||||
|
If the units do not match, ``DimensionError`` is raised.
|
||||||
|
|
||||||
|
This function can be used in kernels and is executed at compilation time.
|
||||||
|
|
||||||
|
There is already unit checking built into the arithmetic, so you typically
|
||||||
|
need to use this function only when using the ``amount`` property of
|
||||||
|
``Quantity``.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if unit is None:
|
||||||
|
if isinstance(value, Quantity):
|
||||||
|
raise DimensionError
|
||||||
|
else:
|
||||||
|
if not isinstance(value, Quantity) or value.unit != unit:
|
||||||
|
raise DimensionError
|
||||||
|
|
|
@ -3,6 +3,7 @@ import textwrap
|
||||||
import ast
|
import ast
|
||||||
import types
|
import types
|
||||||
import builtins
|
import builtins
|
||||||
|
from copy import copy
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -10,7 +11,7 @@ from itertools import zip_longest, chain
|
||||||
|
|
||||||
from artiq.language import core as core_language
|
from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
|
from artiq.transforms.tools import *
|
||||||
|
|
||||||
|
|
||||||
def new_mangled_name(in_use_names, name):
|
def new_mangled_name(in_use_names, name):
|
||||||
|
@ -35,23 +36,6 @@ class AttributeInfo:
|
||||||
self.read_write = read_write
|
self.read_write = read_write
|
||||||
|
|
||||||
|
|
||||||
embeddable_funcs = (
|
|
||||||
core_language.delay, core_language.at, core_language.now,
|
|
||||||
core_language.time_to_cycles, core_language.cycles_to_time,
|
|
||||||
core_language.syscall,
|
|
||||||
range, bool, int, float, round,
|
|
||||||
core_language.int64, core_language.round64, core_language.array,
|
|
||||||
Fraction, units.Quantity, core_language.EncodedException
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_embeddable(func):
|
|
||||||
for ef in embeddable_funcs:
|
|
||||||
if func is ef:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_inlinable(core, func):
|
def is_inlinable(core, func):
|
||||||
if hasattr(func, "k_function_info"):
|
if hasattr(func, "k_function_info"):
|
||||||
if func.k_function_info.core_name == "":
|
if func.k_function_info.core_name == "":
|
||||||
|
@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||||
def inline(core, k_function, k_args, k_kwargs):
|
def inline(core, k_function, k_args, k_kwargs):
|
||||||
# OrderedDict prevents non-determinism in attribute init
|
# OrderedDict prevents non-determinism in attribute init
|
||||||
attribute_namespace = OrderedDict()
|
attribute_namespace = OrderedDict()
|
||||||
in_use_names = {func.__name__ for func in embeddable_funcs}
|
in_use_names = copy(embeddable_func_names)
|
||||||
mappers = types.SimpleNamespace(
|
mappers = types.SimpleNamespace(
|
||||||
rpc=HostObjectMapper(),
|
rpc=HostObjectMapper(),
|
||||||
exception=HostObjectMapper(core_language.first_user_eid)
|
exception=HostObjectMapper(core_language.first_user_eid)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from collections import defaultdict
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
from artiq.transforms.tools import embeddable_func_names
|
||||||
|
|
||||||
|
|
||||||
def _add_units(f, unit_list):
|
def _add_units(f, unit_list):
|
||||||
|
@ -16,6 +17,7 @@ def _add_units(f, unit_list):
|
||||||
class _UnitsLowerer(ast.NodeTransformer):
|
class _UnitsLowerer(ast.NodeTransformer):
|
||||||
def __init__(self, rpc_map):
|
def __init__(self, rpc_map):
|
||||||
self.rpc_map = rpc_map
|
self.rpc_map = rpc_map
|
||||||
|
# (original rpc number, (unit list)) -> new rpc number
|
||||||
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
|
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
|
||||||
self.variable_units = dict()
|
self.variable_units = dict()
|
||||||
|
|
||||||
|
@ -29,6 +31,22 @@ class _UnitsLowerer(ast.NodeTransformer):
|
||||||
node.unit = unit
|
node.unit = unit
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_BoolOp(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
us = [getattr(value, "unit", None) for value in node.values]
|
||||||
|
if not all(u == us[0] for u in us[1:]):
|
||||||
|
raise units.DimensionError
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Compare(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
u0 = getattr(node.left, "unit", None)
|
||||||
|
us = [getattr(comparator, "unit", None)
|
||||||
|
for comparator in node.comparators]
|
||||||
|
if not all(u == u0 for u in us):
|
||||||
|
raise units.DimensionError
|
||||||
|
return node
|
||||||
|
|
||||||
def visit_UnaryOp(self, node):
|
def visit_UnaryOp(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if hasattr(node.operand, "unit"):
|
if hasattr(node.operand, "unit"):
|
||||||
|
@ -47,6 +65,8 @@ class _UnitsLowerer(ast.NodeTransformer):
|
||||||
elif op in (ast.Div, ast.FloorDiv):
|
elif op in (ast.Div, ast.FloorDiv):
|
||||||
unit = units.div_dimension(left_unit, right_unit)
|
unit = units.div_dimension(left_unit, right_unit)
|
||||||
else:
|
else:
|
||||||
|
if left_unit is not None or right_unit is not None:
|
||||||
|
raise units.DimensionError
|
||||||
unit = None
|
unit = None
|
||||||
if unit is not None:
|
if unit is not None:
|
||||||
node.unit = unit
|
node.unit = unit
|
||||||
|
@ -66,14 +86,47 @@ class _UnitsLowerer(ast.NodeTransformer):
|
||||||
amount, unit = node.args
|
amount, unit = node.args
|
||||||
amount.unit = unit.s
|
amount.unit = unit.s
|
||||||
return amount
|
return amount
|
||||||
elif node.func.id == "now":
|
elif node.func.id in ("now", "cycles_to_time"):
|
||||||
node.unit = "s"
|
node.unit = "s"
|
||||||
elif node.func.id == "syscall" and node.args[0].s == "rpc":
|
elif node.func.id == "syscall":
|
||||||
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
|
# only RPCs can have units
|
||||||
rpc_n = node.args[1].n
|
if node.args[0].s == "rpc":
|
||||||
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
|
unit_list = tuple(getattr(arg, "unit", None)
|
||||||
|
for arg in node.args[2:])
|
||||||
|
rpc_n = node.args[1].n
|
||||||
|
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
|
||||||
|
else:
|
||||||
|
if any(hasattr(arg, "unit") for arg in node.args):
|
||||||
|
raise units.DimensionError
|
||||||
|
elif node.func.id in ("delay", "at", "time_to_cycles"):
|
||||||
|
if getattr(node.args[0], "unit", None) != "s":
|
||||||
|
raise units.DimensionError
|
||||||
|
elif node.func.id == "check_unit":
|
||||||
|
self.generic_visit(node)
|
||||||
|
elif node.func.id in embeddable_func_names:
|
||||||
|
# must be last (some embeddable funcs may have units)
|
||||||
|
if any(hasattr(arg, "unit") for arg in node.args):
|
||||||
|
raise units.DimensionError
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_Expr(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
if (isinstance(node.value, ast.Call)
|
||||||
|
and node.value.func.id == "check_unit"):
|
||||||
|
call = node.value
|
||||||
|
if (isinstance(call.args[1], ast.NameConstant)
|
||||||
|
and call.args[1].value is None):
|
||||||
|
if hasattr(call.value.args[0], "unit"):
|
||||||
|
raise units.DimensionError
|
||||||
|
elif isinstance(call.args[1], ast.Str):
|
||||||
|
if getattr(call.args[0], "unit", None) != call.args[1].s:
|
||||||
|
raise units.DimensionError
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
def _update_target(self, target, unit):
|
def _update_target(self, target, unit):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in self.variable_units:
|
if target.id in self.variable_units:
|
||||||
|
|
|
@ -5,6 +5,24 @@ from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
|
||||||
|
|
||||||
|
embeddable_funcs = (
|
||||||
|
core_language.delay, core_language.at, core_language.now,
|
||||||
|
core_language.time_to_cycles, core_language.cycles_to_time,
|
||||||
|
core_language.syscall,
|
||||||
|
range, bool, int, float, round,
|
||||||
|
core_language.int64, core_language.round64, core_language.array,
|
||||||
|
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
|
||||||
|
)
|
||||||
|
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
|
||||||
|
|
||||||
|
|
||||||
|
def is_embeddable(func):
|
||||||
|
for ef in embeddable_funcs:
|
||||||
|
if func is ef:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def eval_ast(expr, symdict=dict()):
|
def eval_ast(expr, symdict=dict()):
|
||||||
if not isinstance(expr, ast.Expression):
|
if not isinstance(expr, ast.Expression):
|
||||||
expr = ast.copy_location(ast.Expression(expr), expr)
|
expr = ast.copy_location(ast.Expression(expr), expr)
|
||||||
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
|
||||||
from artiq import *
|
from artiq import *
|
||||||
from artiq.language.units import Quantity
|
from artiq.language.units import *
|
||||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||||
from artiq.sim import devices as sim_devices
|
from artiq.sim import devices as sim_devices
|
||||||
|
|
||||||
|
@ -56,6 +56,22 @@ class _Misc(AutoContext):
|
||||||
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
||||||
self.inhomogeneous_units.append(Quantity(10, "s"))
|
self.inhomogeneous_units.append(Quantity(10, "s"))
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def dimension_error1(self):
|
||||||
|
print(1*Hz + 1*s)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def dimension_error2(self):
|
||||||
|
print(1*Hz < 1*s)
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def dimension_error3(self):
|
||||||
|
check_unit(1*Hz, "s")
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def dimension_error4(self):
|
||||||
|
delay(10*Hz)
|
||||||
|
|
||||||
|
|
||||||
class _PulseLogger(AutoContext):
|
class _PulseLogger(AutoContext):
|
||||||
parameters = "output_list name"
|
parameters = "output_list name"
|
||||||
|
@ -163,6 +179,14 @@ class ExecutionCase(unittest.TestCase):
|
||||||
Fraction("1.2"))
|
Fraction("1.2"))
|
||||||
self.assertEqual(uut.inhomogeneous_units, [
|
self.assertEqual(uut.inhomogeneous_units, [
|
||||||
Quantity(1000, "Hz"), Quantity(10, "s")])
|
Quantity(1000, "Hz"), Quantity(10, "s")])
|
||||||
|
with self.assertRaises(DimensionError):
|
||||||
|
uut.dimension_error1()
|
||||||
|
with self.assertRaises(DimensionError):
|
||||||
|
uut.dimension_error2()
|
||||||
|
with self.assertRaises(DimensionError):
|
||||||
|
uut.dimension_error3()
|
||||||
|
with self.assertRaises(DimensionError):
|
||||||
|
uut.dimension_error4()
|
||||||
|
|
||||||
def test_pulses(self):
|
def test_pulses(self):
|
||||||
l_device, l_host = [], []
|
l_device, l_host = [], []
|
||||||
|
|
Loading…
Reference in New Issue