forked from M-Labs/artiq
support units in lists
This commit is contained in:
parent
0d10ae7580
commit
5522378c1c
|
@ -1,4 +1,5 @@
|
|||
from artiq.language.core import *
|
||||
from artiq.language.context import *
|
||||
from artiq.language.units import check_unit
|
||||
from artiq.language.units import ps, ns, us, ms, s
|
||||
from artiq.language.units import Hz, kHz, MHz, GHz
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
from fractions import Fraction
|
||||
|
||||
from artiq import *
|
||||
from artiq.language.units import *
|
||||
from artiq.language.units import DimensionError
|
||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||
from artiq.sim import devices as sim_devices
|
||||
|
||||
|
@ -51,18 +51,19 @@ class _Misc(AutoContext):
|
|||
self.input = 84
|
||||
self.inhomogeneous_units = []
|
||||
self.al = [1, 2, 3, 4, 5]
|
||||
self.list_copy_in = [2*Hz, 10*MHz]
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.half_input = self.input//2
|
||||
decimal_fraction = Fraction("1.2")
|
||||
self.decimal_fraction_n = int(decimal_fraction.numerator)
|
||||
self.decimal_fraction_d = int(decimal_fraction.denominator)
|
||||
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
||||
self.inhomogeneous_units.append(Quantity(10, "s"))
|
||||
self.decimal_fraction = Fraction("1.2")
|
||||
self.inhomogeneous_units.append(1000*Hz)
|
||||
self.inhomogeneous_units.append(10*s)
|
||||
self.acc = 0
|
||||
for i in range(len(self.al)):
|
||||
self.acc += self.al[i]
|
||||
self.list_copy_out = self.list_copy_in
|
||||
self.unit_comp = [1*MHz for _ in range(3)]
|
||||
|
||||
@kernel
|
||||
def dimension_error1(self):
|
||||
|
@ -184,12 +185,11 @@ class ExecutionCase(unittest.TestCase):
|
|||
uut = _Misc(core=coredev)
|
||||
uut.run()
|
||||
self.assertEqual(uut.half_input, 42)
|
||||
self.assertEqual(Fraction(uut.decimal_fraction_n,
|
||||
uut.decimal_fraction_d),
|
||||
Fraction("1.2"))
|
||||
self.assertEqual(uut.inhomogeneous_units, [
|
||||
Quantity(1000, "Hz"), Quantity(10, "s")])
|
||||
self.assertEqual(uut.decimal_fraction, Fraction("1.2"))
|
||||
self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s])
|
||||
self.assertEqual(uut.acc, sum(uut.al))
|
||||
self.assertEqual(uut.list_copy_in, uut.list_copy_out)
|
||||
self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)])
|
||||
with self.assertRaises(DimensionError):
|
||||
uut.dimension_error1()
|
||||
with self.assertRaises(DimensionError):
|
||||
|
|
|
@ -8,8 +8,15 @@ from artiq.transforms.tools import embeddable_func_names
|
|||
|
||||
def _add_units(f, unit_list):
|
||||
def wrapper(*args):
|
||||
new_args = [arg if unit is None else units.Quantity(arg, unit)
|
||||
for arg, unit in zip(args, unit_list)]
|
||||
new_args = []
|
||||
for arg, unit in zip(args, unit_list):
|
||||
if unit is None:
|
||||
new_args.append(arg)
|
||||
else:
|
||||
if isinstance(arg, list):
|
||||
new_args.append([units.Quantity(x, unit) for x in arg])
|
||||
else:
|
||||
new_args.append(units.Quantity(arg, unit))
|
||||
return f(*new_args)
|
||||
return wrapper
|
||||
|
||||
|
@ -80,6 +87,20 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
else:
|
||||
return node
|
||||
|
||||
def visit_List(self, node):
|
||||
self.generic_visit(node)
|
||||
us = [getattr(elt, "unit", None) for elt in node.elts]
|
||||
if not all(u == us[0] for u in us[1:]):
|
||||
raise units.DimensionError
|
||||
node.unit = us[0]
|
||||
return node
|
||||
|
||||
def visit_ListComp(self, node):
|
||||
self.generic_visit(node)
|
||||
if hasattr(node.elt, "unit"):
|
||||
node.unit = node.elt.unit
|
||||
return node
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.generic_visit(node)
|
||||
if node.func.id == "Quantity":
|
||||
|
|
Loading…
Reference in New Issue