diff --git a/artiq/__init__.py b/artiq/__init__.py index 24fdfcad1..bba7663cb 100644 --- a/artiq/__init__.py +++ b/artiq/__init__.py @@ -1,4 +1,5 @@ from artiq.language.core import * from artiq.language.context import * +from artiq.language.units import check_unit from artiq.language.units import ps, ns, us, ms, s from artiq.language.units import Hz, kHz, MHz, GHz diff --git a/artiq/test/full_stack.py b/artiq/test/full_stack.py index fd19ac8f6..3e89063ec 100644 --- a/artiq/test/full_stack.py +++ b/artiq/test/full_stack.py @@ -4,7 +4,7 @@ import os from fractions import Fraction from artiq import * -from artiq.language.units import * +from artiq.language.units import DimensionError from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.sim import devices as sim_devices @@ -51,18 +51,19 @@ class _Misc(AutoContext): self.input = 84 self.inhomogeneous_units = [] self.al = [1, 2, 3, 4, 5] + self.list_copy_in = [2*Hz, 10*MHz] @kernel def run(self): self.half_input = self.input//2 - decimal_fraction = Fraction("1.2") - self.decimal_fraction_n = int(decimal_fraction.numerator) - self.decimal_fraction_d = int(decimal_fraction.denominator) - self.inhomogeneous_units.append(Quantity(1000, "Hz")) - self.inhomogeneous_units.append(Quantity(10, "s")) + self.decimal_fraction = Fraction("1.2") + self.inhomogeneous_units.append(1000*Hz) + self.inhomogeneous_units.append(10*s) self.acc = 0 for i in range(len(self.al)): self.acc += self.al[i] + self.list_copy_out = self.list_copy_in + self.unit_comp = [1*MHz for _ in range(3)] @kernel def dimension_error1(self): @@ -184,12 +185,11 @@ class ExecutionCase(unittest.TestCase): uut = _Misc(core=coredev) uut.run() self.assertEqual(uut.half_input, 42) - self.assertEqual(Fraction(uut.decimal_fraction_n, - uut.decimal_fraction_d), - Fraction("1.2")) - self.assertEqual(uut.inhomogeneous_units, [ - Quantity(1000, "Hz"), Quantity(10, "s")]) + self.assertEqual(uut.decimal_fraction, Fraction("1.2")) + self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s]) self.assertEqual(uut.acc, sum(uut.al)) + self.assertEqual(uut.list_copy_in, uut.list_copy_out) + self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)]) with self.assertRaises(DimensionError): uut.dimension_error1() with self.assertRaises(DimensionError): diff --git a/artiq/transforms/lower_units.py b/artiq/transforms/lower_units.py index b73acfe93..7573c2d3d 100644 --- a/artiq/transforms/lower_units.py +++ b/artiq/transforms/lower_units.py @@ -8,8 +8,15 @@ from artiq.transforms.tools import embeddable_func_names def _add_units(f, unit_list): def wrapper(*args): - new_args = [arg if unit is None else units.Quantity(arg, unit) - for arg, unit in zip(args, unit_list)] + new_args = [] + for arg, unit in zip(args, unit_list): + if unit is None: + new_args.append(arg) + else: + if isinstance(arg, list): + new_args.append([units.Quantity(x, unit) for x in arg]) + else: + new_args.append(units.Quantity(arg, unit)) return f(*new_args) return wrapper @@ -80,6 +87,20 @@ class _UnitsLowerer(ast.NodeTransformer): else: return node + def visit_List(self, node): + self.generic_visit(node) + us = [getattr(elt, "unit", None) for elt in node.elts] + if not all(u == us[0] for u in us[1:]): + raise units.DimensionError + node.unit = us[0] + return node + + def visit_ListComp(self, node): + self.generic_visit(node) + if hasattr(node.elt, "unit"): + node.unit = node.elt.unit + return node + def visit_Call(self, node): self.generic_visit(node) if node.func.id == "Quantity":