units: error checking

This commit is contained in:
Sebastien Bourdeauducq 2014-11-22 16:56:51 -08:00
parent d59d110f78
commit a3f981726a
6 changed files with 172 additions and 32 deletions

View File

@ -1,4 +1,7 @@
"""Core ARTIQ extensions to the Python language."""
"""
Core ARTIQ extensions to the Python language.
"""
from collections import namedtuple as _namedtuple
from copy import copy as _copy

View File

@ -1,17 +1,36 @@
"""
Definition and management of physical units.
"""
from fractions import Fraction as _Fraction
_prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)
class DimensionError(Exception):
"""Raised when attempting an operation with incompatible units.
When targeting the core device, all units are statically managed at
compilation time. Thus, when raised by functions in this module, this
exception cannot be caught in the kernel as it is raised by the compiler
instead.
"""
pass
def mul_dimension(l, r):
"""Returns the unit obtained by multiplying unit ``l`` with unit ``r``.
Raises ``DimensionError`` if the resulting unit is not implemented.
"""
if l is None:
return r
if r is None:
return l
if {l, r} == {"Hz", "s"}:
return None
raise DimensionError
def _rmul_dimension(l, r):
@ -19,6 +38,11 @@ def _rmul_dimension(l, r):
def div_dimension(l, r):
"""Returns the unit obtained by dividing unit ``l`` with unit ``r``.
Raises ``DimensionError`` if the resulting unit is not implemented.
"""
if l == r:
return None
if r is None:
@ -28,6 +52,7 @@ def div_dimension(l, r):
return "Hz"
if r == "Hz":
return "s"
raise DimensionError
def _rdiv_dimension(l, r):
@ -35,10 +60,20 @@ def _rdiv_dimension(l, r):
def addsub_dimension(x, y):
"""Returns the unit obtained by adding or subtracting unit ``l`` with
unit ``r``.
Raises ``DimensionError`` if ``l`` and ``r`` are different.
"""
if x == y:
return x
else:
return None
raise DimensionError
_prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)
def _format(amount, unit):
@ -139,9 +174,9 @@ class Quantity:
# comparisons
def _cmp(self, other, opf_name):
if isinstance(other, Quantity):
other = other.amount
return getattr(self.amount, opf_name)(other)
if not isinstance(other, Quantity) or other.unit != self.unit:
raise DimensionError
return getattr(self.amount, opf_name)(other.amount)
def __lt__(self, other):
return self._cmp(other, "__lt__")
@ -173,3 +208,26 @@ def _register_unit(unit, prefixes):
_register_unit("s", "pnum_")
_register_unit("Hz", "_kMG")
def check_unit(value, unit):
"""Checks that the value has the specified unit. Unit specification is
a string representing the unit without any prefix (e.g. ``s``, ``Hz``).
Checking for a dimensionless value (not a ``Quantity`` instance) is done
by setting ``unit`` to ``None``.
If the units do not match, ``DimensionError`` is raised.
This function can be used in kernels and is executed at compilation time.
There is already unit checking built into the arithmetic, so you typically
need to use this function only when using the ``amount`` property of
``Quantity``.
"""
if unit is None:
if isinstance(value, Quantity):
raise DimensionError
else:
if not isinstance(value, Quantity) or value.unit != unit:
raise DimensionError

View File

@ -3,6 +3,7 @@ import textwrap
import ast
import types
import builtins
from copy import copy
from fractions import Fraction
from collections import OrderedDict
from functools import partial
@ -10,7 +11,7 @@ from itertools import zip_longest, chain
from artiq.language import core as core_language
from artiq.language import units
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
from artiq.transforms.tools import *
def new_mangled_name(in_use_names, name):
@ -35,23 +36,6 @@ class AttributeInfo:
self.read_write = read_write
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
)
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def is_inlinable(core, func):
if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "":
@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs):
# OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict()
in_use_names = {func.__name__ for func in embeddable_funcs}
in_use_names = copy(embeddable_func_names)
mappers = types.SimpleNamespace(
rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid)

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from copy import copy
from artiq.language import units
from artiq.transforms.tools import embeddable_func_names
def _add_units(f, unit_list):
@ -16,6 +17,7 @@ def _add_units(f, unit_list):
class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, rpc_map):
self.rpc_map = rpc_map
# (original rpc number, (unit list)) -> new rpc number
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
self.variable_units = dict()
@ -29,6 +31,22 @@ class _UnitsLowerer(ast.NodeTransformer):
node.unit = unit
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
us = [getattr(value, "unit", None) for value in node.values]
if not all(u == us[0] for u in us[1:]):
raise units.DimensionError
return node
def visit_Compare(self, node):
self.generic_visit(node)
u0 = getattr(node.left, "unit", None)
us = [getattr(comparator, "unit", None)
for comparator in node.comparators]
if not all(u == u0 for u in us):
raise units.DimensionError
return node
def visit_UnaryOp(self, node):
self.generic_visit(node)
if hasattr(node.operand, "unit"):
@ -47,6 +65,8 @@ class _UnitsLowerer(ast.NodeTransformer):
elif op in (ast.Div, ast.FloorDiv):
unit = units.div_dimension(left_unit, right_unit)
else:
if left_unit is not None or right_unit is not None:
raise units.DimensionError
unit = None
if unit is not None:
node.unit = unit
@ -66,14 +86,47 @@ class _UnitsLowerer(ast.NodeTransformer):
amount, unit = node.args
amount.unit = unit.s
return amount
elif node.func.id == "now":
elif node.func.id in ("now", "cycles_to_time"):
node.unit = "s"
elif node.func.id == "syscall" and node.args[0].s == "rpc":
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
rpc_n = node.args[1].n
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
elif node.func.id == "syscall":
# only RPCs can have units
if node.args[0].s == "rpc":
unit_list = tuple(getattr(arg, "unit", None)
for arg in node.args[2:])
rpc_n = node.args[1].n
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
else:
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
elif node.func.id in ("delay", "at", "time_to_cycles"):
if getattr(node.args[0], "unit", None) != "s":
raise units.DimensionError
elif node.func.id == "check_unit":
self.generic_visit(node)
elif node.func.id in embeddable_func_names:
# must be last (some embeddable funcs may have units)
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
return node
def visit_Expr(self, node):
self.generic_visit(node)
if (isinstance(node.value, ast.Call)
and node.value.func.id == "check_unit"):
call = node.value
if (isinstance(call.args[1], ast.NameConstant)
and call.args[1].value is None):
if hasattr(call.value.args[0], "unit"):
raise units.DimensionError
elif isinstance(call.args[1], ast.Str):
if getattr(call.args[0], "unit", None) != call.args[1].s:
raise units.DimensionError
else:
raise NotImplementedError
return None
else:
return node
def _update_target(self, target, unit):
if isinstance(target, ast.Name):
if target.id in self.variable_units:

View File

@ -5,6 +5,24 @@ from artiq.language import core as core_language
from artiq.language import units
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
)
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr)

View File

@ -4,7 +4,7 @@ import os
from fractions import Fraction
from artiq import *
from artiq.language.units import Quantity
from artiq.language.units import *
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices
@ -56,6 +56,22 @@ class _Misc(AutoContext):
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s"))
@kernel
def dimension_error1(self):
print(1*Hz + 1*s)
@kernel
def dimension_error2(self):
print(1*Hz < 1*s)
@kernel
def dimension_error3(self):
check_unit(1*Hz, "s")
@kernel
def dimension_error4(self):
delay(10*Hz)
class _PulseLogger(AutoContext):
parameters = "output_list name"
@ -163,6 +179,14 @@ class ExecutionCase(unittest.TestCase):
Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")])
with self.assertRaises(DimensionError):
uut.dimension_error1()
with self.assertRaises(DimensionError):
uut.dimension_error2()
with self.assertRaises(DimensionError):
uut.dimension_error3()
with self.assertRaises(DimensionError):
uut.dimension_error4()
def test_pulses(self):
l_device, l_host = [], []