From a3f981726a49af7ff25ac29c73e8bcad2f7da467 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 22 Nov 2014 16:56:51 -0800 Subject: [PATCH] units: error checking --- artiq/language/core.py | 5 ++- artiq/language/units.py | 70 ++++++++++++++++++++++++++++++--- artiq/transforms/inline.py | 22 ++--------- artiq/transforms/lower_units.py | 63 ++++++++++++++++++++++++++--- artiq/transforms/tools.py | 18 +++++++++ test/full_stack.py | 26 +++++++++++- 6 files changed, 172 insertions(+), 32 deletions(-) diff --git a/artiq/language/core.py b/artiq/language/core.py index d47432ce0..9c98450cf 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -1,4 +1,7 @@ -"""Core ARTIQ extensions to the Python language.""" +""" +Core ARTIQ extensions to the Python language. + +""" from collections import namedtuple as _namedtuple from copy import copy as _copy diff --git a/artiq/language/units.py b/artiq/language/units.py index 6a7542eb3..f269748ea 100644 --- a/artiq/language/units.py +++ b/artiq/language/units.py @@ -1,17 +1,36 @@ +""" +Definition and management of physical units. + +""" + from fractions import Fraction as _Fraction -_prefixes_str = "pnum_kMG" -_smallest_prefix = _Fraction(1, 10**12) +class DimensionError(Exception): + """Raised when attempting an operation with incompatible units. + + When targeting the core device, all units are statically managed at + compilation time. Thus, when raised by functions in this module, this + exception cannot be caught in the kernel as it is raised by the compiler + instead. + + """ + pass def mul_dimension(l, r): + """Returns the unit obtained by multiplying unit ``l`` with unit ``r``. + + Raises ``DimensionError`` if the resulting unit is not implemented. + + """ if l is None: return r if r is None: return l if {l, r} == {"Hz", "s"}: return None + raise DimensionError def _rmul_dimension(l, r): @@ -19,6 +38,11 @@ def _rmul_dimension(l, r): def div_dimension(l, r): + """Returns the unit obtained by dividing unit ``l`` with unit ``r``. + + Raises ``DimensionError`` if the resulting unit is not implemented. + + """ if l == r: return None if r is None: @@ -28,6 +52,7 @@ def div_dimension(l, r): return "Hz" if r == "Hz": return "s" + raise DimensionError def _rdiv_dimension(l, r): @@ -35,10 +60,20 @@ def _rdiv_dimension(l, r): def addsub_dimension(x, y): + """Returns the unit obtained by adding or subtracting unit ``l`` with + unit ``r``. + + Raises ``DimensionError`` if ``l`` and ``r`` are different. + + """ if x == y: return x else: - return None + raise DimensionError + + +_prefixes_str = "pnum_kMG" +_smallest_prefix = _Fraction(1, 10**12) def _format(amount, unit): @@ -139,9 +174,9 @@ class Quantity: # comparisons def _cmp(self, other, opf_name): - if isinstance(other, Quantity): - other = other.amount - return getattr(self.amount, opf_name)(other) + if not isinstance(other, Quantity) or other.unit != self.unit: + raise DimensionError + return getattr(self.amount, opf_name)(other.amount) def __lt__(self, other): return self._cmp(other, "__lt__") @@ -173,3 +208,26 @@ def _register_unit(unit, prefixes): _register_unit("s", "pnum_") _register_unit("Hz", "_kMG") + + +def check_unit(value, unit): + """Checks that the value has the specified unit. Unit specification is + a string representing the unit without any prefix (e.g. ``s``, ``Hz``). + Checking for a dimensionless value (not a ``Quantity`` instance) is done + by setting ``unit`` to ``None``. + + If the units do not match, ``DimensionError`` is raised. + + This function can be used in kernels and is executed at compilation time. + + There is already unit checking built into the arithmetic, so you typically + need to use this function only when using the ``amount`` property of + ``Quantity``. + + """ + if unit is None: + if isinstance(value, Quantity): + raise DimensionError + else: + if not isinstance(value, Quantity) or value.unit != unit: + raise DimensionError diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 644d2887f..98851a360 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -3,6 +3,7 @@ import textwrap import ast import types import builtins +from copy import copy from fractions import Fraction from collections import OrderedDict from functools import partial @@ -10,7 +11,7 @@ from itertools import zip_longest, chain from artiq.language import core as core_language from artiq.language import units -from artiq.transforms.tools import value_to_ast, NotASTRepresentable +from artiq.transforms.tools import * def new_mangled_name(in_use_names, name): @@ -35,23 +36,6 @@ class AttributeInfo: self.read_write = read_write -embeddable_funcs = ( - core_language.delay, core_language.at, core_language.now, - core_language.time_to_cycles, core_language.cycles_to_time, - core_language.syscall, - range, bool, int, float, round, - core_language.int64, core_language.round64, core_language.array, - Fraction, units.Quantity, core_language.EncodedException -) - - -def is_embeddable(func): - for ef in embeddable_funcs: - if func is ef: - return True - return False - - def is_inlinable(core, func): if hasattr(func, "k_function_info"): if func.k_function_info.core_name == "": @@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): def inline(core, k_function, k_args, k_kwargs): # OrderedDict prevents non-determinism in attribute init attribute_namespace = OrderedDict() - in_use_names = {func.__name__ for func in embeddable_funcs} + in_use_names = copy(embeddable_func_names) mappers = types.SimpleNamespace( rpc=HostObjectMapper(), exception=HostObjectMapper(core_language.first_user_eid) diff --git a/artiq/transforms/lower_units.py b/artiq/transforms/lower_units.py index 13c357cfa..b73acfe93 100644 --- a/artiq/transforms/lower_units.py +++ b/artiq/transforms/lower_units.py @@ -3,6 +3,7 @@ from collections import defaultdict from copy import copy from artiq.language import units +from artiq.transforms.tools import embeddable_func_names def _add_units(f, unit_list): @@ -16,6 +17,7 @@ def _add_units(f, unit_list): class _UnitsLowerer(ast.NodeTransformer): def __init__(self, rpc_map): self.rpc_map = rpc_map + # (original rpc number, (unit list)) -> new rpc number self.rpc_remap = defaultdict(lambda: len(self.rpc_remap)) self.variable_units = dict() @@ -29,6 +31,22 @@ class _UnitsLowerer(ast.NodeTransformer): node.unit = unit return node + def visit_BoolOp(self, node): + self.generic_visit(node) + us = [getattr(value, "unit", None) for value in node.values] + if not all(u == us[0] for u in us[1:]): + raise units.DimensionError + return node + + def visit_Compare(self, node): + self.generic_visit(node) + u0 = getattr(node.left, "unit", None) + us = [getattr(comparator, "unit", None) + for comparator in node.comparators] + if not all(u == u0 for u in us): + raise units.DimensionError + return node + def visit_UnaryOp(self, node): self.generic_visit(node) if hasattr(node.operand, "unit"): @@ -47,6 +65,8 @@ class _UnitsLowerer(ast.NodeTransformer): elif op in (ast.Div, ast.FloorDiv): unit = units.div_dimension(left_unit, right_unit) else: + if left_unit is not None or right_unit is not None: + raise units.DimensionError unit = None if unit is not None: node.unit = unit @@ -66,14 +86,47 @@ class _UnitsLowerer(ast.NodeTransformer): amount, unit = node.args amount.unit = unit.s return amount - elif node.func.id == "now": + elif node.func.id in ("now", "cycles_to_time"): node.unit = "s" - elif node.func.id == "syscall" and node.args[0].s == "rpc": - unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:]) - rpc_n = node.args[1].n - node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))] + elif node.func.id == "syscall": + # only RPCs can have units + if node.args[0].s == "rpc": + unit_list = tuple(getattr(arg, "unit", None) + for arg in node.args[2:]) + rpc_n = node.args[1].n + node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))] + else: + if any(hasattr(arg, "unit") for arg in node.args): + raise units.DimensionError + elif node.func.id in ("delay", "at", "time_to_cycles"): + if getattr(node.args[0], "unit", None) != "s": + raise units.DimensionError + elif node.func.id == "check_unit": + self.generic_visit(node) + elif node.func.id in embeddable_func_names: + # must be last (some embeddable funcs may have units) + if any(hasattr(arg, "unit") for arg in node.args): + raise units.DimensionError return node + def visit_Expr(self, node): + self.generic_visit(node) + if (isinstance(node.value, ast.Call) + and node.value.func.id == "check_unit"): + call = node.value + if (isinstance(call.args[1], ast.NameConstant) + and call.args[1].value is None): + if hasattr(call.value.args[0], "unit"): + raise units.DimensionError + elif isinstance(call.args[1], ast.Str): + if getattr(call.args[0], "unit", None) != call.args[1].s: + raise units.DimensionError + else: + raise NotImplementedError + return None + else: + return node + def _update_target(self, target, unit): if isinstance(target, ast.Name): if target.id in self.variable_units: diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py index cfec5396c..2ae3a6cb5 100644 --- a/artiq/transforms/tools.py +++ b/artiq/transforms/tools.py @@ -5,6 +5,24 @@ from artiq.language import core as core_language from artiq.language import units +embeddable_funcs = ( + core_language.delay, core_language.at, core_language.now, + core_language.time_to_cycles, core_language.cycles_to_time, + core_language.syscall, + range, bool, int, float, round, + core_language.int64, core_language.round64, core_language.array, + Fraction, units.Quantity, units.check_unit, core_language.EncodedException +) +embeddable_func_names = {func.__name__ for func in embeddable_funcs} + + +def is_embeddable(func): + for ef in embeddable_funcs: + if func is ef: + return True + return False + + def eval_ast(expr, symdict=dict()): if not isinstance(expr, ast.Expression): expr = ast.copy_location(ast.Expression(expr), expr) diff --git a/test/full_stack.py b/test/full_stack.py index a68aa499c..0ccba86d2 100644 --- a/test/full_stack.py +++ b/test/full_stack.py @@ -4,7 +4,7 @@ import os from fractions import Fraction from artiq import * -from artiq.language.units import Quantity +from artiq.language.units import * from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.sim import devices as sim_devices @@ -56,6 +56,22 @@ class _Misc(AutoContext): self.inhomogeneous_units.append(Quantity(1000, "Hz")) self.inhomogeneous_units.append(Quantity(10, "s")) + @kernel + def dimension_error1(self): + print(1*Hz + 1*s) + + @kernel + def dimension_error2(self): + print(1*Hz < 1*s) + + @kernel + def dimension_error3(self): + check_unit(1*Hz, "s") + + @kernel + def dimension_error4(self): + delay(10*Hz) + class _PulseLogger(AutoContext): parameters = "output_list name" @@ -163,6 +179,14 @@ class ExecutionCase(unittest.TestCase): Fraction("1.2")) self.assertEqual(uut.inhomogeneous_units, [ Quantity(1000, "Hz"), Quantity(10, "s")]) + with self.assertRaises(DimensionError): + uut.dimension_error1() + with self.assertRaises(DimensionError): + uut.dimension_error2() + with self.assertRaises(DimensionError): + uut.dimension_error3() + with self.assertRaises(DimensionError): + uut.dimension_error4() def test_pulses(self): l_device, l_host = [], []