units: error checking

This commit is contained in:
Sebastien Bourdeauducq 2014-11-22 16:56:51 -08:00
parent d59d110f78
commit a3f981726a
6 changed files with 172 additions and 32 deletions

View File

@ -1,4 +1,7 @@
"""Core ARTIQ extensions to the Python language.""" """
Core ARTIQ extensions to the Python language.
"""
from collections import namedtuple as _namedtuple from collections import namedtuple as _namedtuple
from copy import copy as _copy from copy import copy as _copy

View File

@ -1,17 +1,36 @@
"""
Definition and management of physical units.
"""
from fractions import Fraction as _Fraction from fractions import Fraction as _Fraction
_prefixes_str = "pnum_kMG" class DimensionError(Exception):
_smallest_prefix = _Fraction(1, 10**12) """Raised when attempting an operation with incompatible units.
When targeting the core device, all units are statically managed at
compilation time. Thus, when raised by functions in this module, this
exception cannot be caught in the kernel as it is raised by the compiler
instead.
"""
pass
def mul_dimension(l, r): def mul_dimension(l, r):
"""Returns the unit obtained by multiplying unit ``l`` with unit ``r``.
Raises ``DimensionError`` if the resulting unit is not implemented.
"""
if l is None: if l is None:
return r return r
if r is None: if r is None:
return l return l
if {l, r} == {"Hz", "s"}: if {l, r} == {"Hz", "s"}:
return None return None
raise DimensionError
def _rmul_dimension(l, r): def _rmul_dimension(l, r):
@ -19,6 +38,11 @@ def _rmul_dimension(l, r):
def div_dimension(l, r): def div_dimension(l, r):
"""Returns the unit obtained by dividing unit ``l`` with unit ``r``.
Raises ``DimensionError`` if the resulting unit is not implemented.
"""
if l == r: if l == r:
return None return None
if r is None: if r is None:
@ -28,6 +52,7 @@ def div_dimension(l, r):
return "Hz" return "Hz"
if r == "Hz": if r == "Hz":
return "s" return "s"
raise DimensionError
def _rdiv_dimension(l, r): def _rdiv_dimension(l, r):
@ -35,10 +60,20 @@ def _rdiv_dimension(l, r):
def addsub_dimension(x, y): def addsub_dimension(x, y):
"""Returns the unit obtained by adding or subtracting unit ``l`` with
unit ``r``.
Raises ``DimensionError`` if ``l`` and ``r`` are different.
"""
if x == y: if x == y:
return x return x
else: else:
return None raise DimensionError
_prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)
def _format(amount, unit): def _format(amount, unit):
@ -139,9 +174,9 @@ class Quantity:
# comparisons # comparisons
def _cmp(self, other, opf_name): def _cmp(self, other, opf_name):
if isinstance(other, Quantity): if not isinstance(other, Quantity) or other.unit != self.unit:
other = other.amount raise DimensionError
return getattr(self.amount, opf_name)(other) return getattr(self.amount, opf_name)(other.amount)
def __lt__(self, other): def __lt__(self, other):
return self._cmp(other, "__lt__") return self._cmp(other, "__lt__")
@ -173,3 +208,26 @@ def _register_unit(unit, prefixes):
_register_unit("s", "pnum_") _register_unit("s", "pnum_")
_register_unit("Hz", "_kMG") _register_unit("Hz", "_kMG")
def check_unit(value, unit):
"""Checks that the value has the specified unit. Unit specification is
a string representing the unit without any prefix (e.g. ``s``, ``Hz``).
Checking for a dimensionless value (not a ``Quantity`` instance) is done
by setting ``unit`` to ``None``.
If the units do not match, ``DimensionError`` is raised.
This function can be used in kernels and is executed at compilation time.
There is already unit checking built into the arithmetic, so you typically
need to use this function only when using the ``amount`` property of
``Quantity``.
"""
if unit is None:
if isinstance(value, Quantity):
raise DimensionError
else:
if not isinstance(value, Quantity) or value.unit != unit:
raise DimensionError

View File

@ -3,6 +3,7 @@ import textwrap
import ast import ast
import types import types
import builtins import builtins
from copy import copy
from fractions import Fraction from fractions import Fraction
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
@ -10,7 +11,7 @@ from itertools import zip_longest, chain
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
from artiq.transforms.tools import value_to_ast, NotASTRepresentable from artiq.transforms.tools import *
def new_mangled_name(in_use_names, name): def new_mangled_name(in_use_names, name):
@ -35,23 +36,6 @@ class AttributeInfo:
self.read_write = read_write self.read_write = read_write
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
)
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def is_inlinable(core, func): def is_inlinable(core, func):
if hasattr(func, "k_function_info"): if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "": if func.k_function_info.core_name == "":
@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs): def inline(core, k_function, k_args, k_kwargs):
# OrderedDict prevents non-determinism in attribute init # OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict() attribute_namespace = OrderedDict()
in_use_names = {func.__name__ for func in embeddable_funcs} in_use_names = copy(embeddable_func_names)
mappers = types.SimpleNamespace( mappers = types.SimpleNamespace(
rpc=HostObjectMapper(), rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid) exception=HostObjectMapper(core_language.first_user_eid)

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from copy import copy from copy import copy
from artiq.language import units from artiq.language import units
from artiq.transforms.tools import embeddable_func_names
def _add_units(f, unit_list): def _add_units(f, unit_list):
@ -16,6 +17,7 @@ def _add_units(f, unit_list):
class _UnitsLowerer(ast.NodeTransformer): class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, rpc_map): def __init__(self, rpc_map):
self.rpc_map = rpc_map self.rpc_map = rpc_map
# (original rpc number, (unit list)) -> new rpc number
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap)) self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
self.variable_units = dict() self.variable_units = dict()
@ -29,6 +31,22 @@ class _UnitsLowerer(ast.NodeTransformer):
node.unit = unit node.unit = unit
return node return node
def visit_BoolOp(self, node):
self.generic_visit(node)
us = [getattr(value, "unit", None) for value in node.values]
if not all(u == us[0] for u in us[1:]):
raise units.DimensionError
return node
def visit_Compare(self, node):
self.generic_visit(node)
u0 = getattr(node.left, "unit", None)
us = [getattr(comparator, "unit", None)
for comparator in node.comparators]
if not all(u == u0 for u in us):
raise units.DimensionError
return node
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
self.generic_visit(node) self.generic_visit(node)
if hasattr(node.operand, "unit"): if hasattr(node.operand, "unit"):
@ -47,6 +65,8 @@ class _UnitsLowerer(ast.NodeTransformer):
elif op in (ast.Div, ast.FloorDiv): elif op in (ast.Div, ast.FloorDiv):
unit = units.div_dimension(left_unit, right_unit) unit = units.div_dimension(left_unit, right_unit)
else: else:
if left_unit is not None or right_unit is not None:
raise units.DimensionError
unit = None unit = None
if unit is not None: if unit is not None:
node.unit = unit node.unit = unit
@ -66,14 +86,47 @@ class _UnitsLowerer(ast.NodeTransformer):
amount, unit = node.args amount, unit = node.args
amount.unit = unit.s amount.unit = unit.s
return amount return amount
elif node.func.id == "now": elif node.func.id in ("now", "cycles_to_time"):
node.unit = "s" node.unit = "s"
elif node.func.id == "syscall" and node.args[0].s == "rpc": elif node.func.id == "syscall":
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:]) # only RPCs can have units
rpc_n = node.args[1].n if node.args[0].s == "rpc":
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))] unit_list = tuple(getattr(arg, "unit", None)
for arg in node.args[2:])
rpc_n = node.args[1].n
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
else:
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
elif node.func.id in ("delay", "at", "time_to_cycles"):
if getattr(node.args[0], "unit", None) != "s":
raise units.DimensionError
elif node.func.id == "check_unit":
self.generic_visit(node)
elif node.func.id in embeddable_func_names:
# must be last (some embeddable funcs may have units)
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
return node return node
def visit_Expr(self, node):
self.generic_visit(node)
if (isinstance(node.value, ast.Call)
and node.value.func.id == "check_unit"):
call = node.value
if (isinstance(call.args[1], ast.NameConstant)
and call.args[1].value is None):
if hasattr(call.value.args[0], "unit"):
raise units.DimensionError
elif isinstance(call.args[1], ast.Str):
if getattr(call.args[0], "unit", None) != call.args[1].s:
raise units.DimensionError
else:
raise NotImplementedError
return None
else:
return node
def _update_target(self, target, unit): def _update_target(self, target, unit):
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in self.variable_units: if target.id in self.variable_units:

View File

@ -5,6 +5,24 @@ from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
)
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def eval_ast(expr, symdict=dict()): def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression): if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr) expr = ast.copy_location(ast.Expression(expr), expr)

View File

@ -4,7 +4,7 @@ import os
from fractions import Fraction from fractions import Fraction
from artiq import * from artiq import *
from artiq.language.units import Quantity from artiq.language.units import *
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices from artiq.sim import devices as sim_devices
@ -56,6 +56,22 @@ class _Misc(AutoContext):
self.inhomogeneous_units.append(Quantity(1000, "Hz")) self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s")) self.inhomogeneous_units.append(Quantity(10, "s"))
@kernel
def dimension_error1(self):
print(1*Hz + 1*s)
@kernel
def dimension_error2(self):
print(1*Hz < 1*s)
@kernel
def dimension_error3(self):
check_unit(1*Hz, "s")
@kernel
def dimension_error4(self):
delay(10*Hz)
class _PulseLogger(AutoContext): class _PulseLogger(AutoContext):
parameters = "output_list name" parameters = "output_list name"
@ -163,6 +179,14 @@ class ExecutionCase(unittest.TestCase):
Fraction("1.2")) Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [ self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")]) Quantity(1000, "Hz"), Quantity(10, "s")])
with self.assertRaises(DimensionError):
uut.dimension_error1()
with self.assertRaises(DimensionError):
uut.dimension_error2()
with self.assertRaises(DimensionError):
uut.dimension_error3()
with self.assertRaises(DimensionError):
uut.dimension_error4()
def test_pulses(self): def test_pulses(self):
l_device, l_host = [], [] l_device, l_host = [], []