forked from M-Labs/artiq
support units in lists
This commit is contained in:
parent
0d10ae7580
commit
5522378c1c
|
@ -1,4 +1,5 @@
|
||||||
from artiq.language.core import *
|
from artiq.language.core import *
|
||||||
from artiq.language.context import *
|
from artiq.language.context import *
|
||||||
|
from artiq.language.units import check_unit
|
||||||
from artiq.language.units import ps, ns, us, ms, s
|
from artiq.language.units import ps, ns, us, ms, s
|
||||||
from artiq.language.units import Hz, kHz, MHz, GHz
|
from artiq.language.units import Hz, kHz, MHz, GHz
|
||||||
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
|
||||||
from artiq import *
|
from artiq import *
|
||||||
from artiq.language.units import *
|
from artiq.language.units import DimensionError
|
||||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||||
from artiq.sim import devices as sim_devices
|
from artiq.sim import devices as sim_devices
|
||||||
|
|
||||||
|
@ -51,18 +51,19 @@ class _Misc(AutoContext):
|
||||||
self.input = 84
|
self.input = 84
|
||||||
self.inhomogeneous_units = []
|
self.inhomogeneous_units = []
|
||||||
self.al = [1, 2, 3, 4, 5]
|
self.al = [1, 2, 3, 4, 5]
|
||||||
|
self.list_copy_in = [2*Hz, 10*MHz]
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.half_input = self.input//2
|
self.half_input = self.input//2
|
||||||
decimal_fraction = Fraction("1.2")
|
self.decimal_fraction = Fraction("1.2")
|
||||||
self.decimal_fraction_n = int(decimal_fraction.numerator)
|
self.inhomogeneous_units.append(1000*Hz)
|
||||||
self.decimal_fraction_d = int(decimal_fraction.denominator)
|
self.inhomogeneous_units.append(10*s)
|
||||||
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
|
||||||
self.inhomogeneous_units.append(Quantity(10, "s"))
|
|
||||||
self.acc = 0
|
self.acc = 0
|
||||||
for i in range(len(self.al)):
|
for i in range(len(self.al)):
|
||||||
self.acc += self.al[i]
|
self.acc += self.al[i]
|
||||||
|
self.list_copy_out = self.list_copy_in
|
||||||
|
self.unit_comp = [1*MHz for _ in range(3)]
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def dimension_error1(self):
|
def dimension_error1(self):
|
||||||
|
@ -184,12 +185,11 @@ class ExecutionCase(unittest.TestCase):
|
||||||
uut = _Misc(core=coredev)
|
uut = _Misc(core=coredev)
|
||||||
uut.run()
|
uut.run()
|
||||||
self.assertEqual(uut.half_input, 42)
|
self.assertEqual(uut.half_input, 42)
|
||||||
self.assertEqual(Fraction(uut.decimal_fraction_n,
|
self.assertEqual(uut.decimal_fraction, Fraction("1.2"))
|
||||||
uut.decimal_fraction_d),
|
self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s])
|
||||||
Fraction("1.2"))
|
|
||||||
self.assertEqual(uut.inhomogeneous_units, [
|
|
||||||
Quantity(1000, "Hz"), Quantity(10, "s")])
|
|
||||||
self.assertEqual(uut.acc, sum(uut.al))
|
self.assertEqual(uut.acc, sum(uut.al))
|
||||||
|
self.assertEqual(uut.list_copy_in, uut.list_copy_out)
|
||||||
|
self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)])
|
||||||
with self.assertRaises(DimensionError):
|
with self.assertRaises(DimensionError):
|
||||||
uut.dimension_error1()
|
uut.dimension_error1()
|
||||||
with self.assertRaises(DimensionError):
|
with self.assertRaises(DimensionError):
|
||||||
|
|
|
@ -8,8 +8,15 @@ from artiq.transforms.tools import embeddable_func_names
|
||||||
|
|
||||||
def _add_units(f, unit_list):
|
def _add_units(f, unit_list):
|
||||||
def wrapper(*args):
|
def wrapper(*args):
|
||||||
new_args = [arg if unit is None else units.Quantity(arg, unit)
|
new_args = []
|
||||||
for arg, unit in zip(args, unit_list)]
|
for arg, unit in zip(args, unit_list):
|
||||||
|
if unit is None:
|
||||||
|
new_args.append(arg)
|
||||||
|
else:
|
||||||
|
if isinstance(arg, list):
|
||||||
|
new_args.append([units.Quantity(x, unit) for x in arg])
|
||||||
|
else:
|
||||||
|
new_args.append(units.Quantity(arg, unit))
|
||||||
return f(*new_args)
|
return f(*new_args)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
@ -80,6 +87,20 @@ class _UnitsLowerer(ast.NodeTransformer):
|
||||||
else:
|
else:
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_List(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
us = [getattr(elt, "unit", None) for elt in node.elts]
|
||||||
|
if not all(u == us[0] for u in us[1:]):
|
||||||
|
raise units.DimensionError
|
||||||
|
node.unit = us[0]
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_ListComp(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
if hasattr(node.elt, "unit"):
|
||||||
|
node.unit = node.elt.unit
|
||||||
|
return node
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if node.func.id == "Quantity":
|
if node.func.id == "Quantity":
|
||||||
|
|
Loading…
Reference in New Issue