2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-26 11:48:27 +08:00

support units in lists

This commit is contained in:
Sebastien Bourdeauducq 2014-12-19 14:34:23 +08:00
parent 0d10ae7580
commit 5522378c1c
3 changed files with 35 additions and 13 deletions

View File

@ -1,4 +1,5 @@
from artiq.language.core import * from artiq.language.core import *
from artiq.language.context import * from artiq.language.context import *
from artiq.language.units import check_unit
from artiq.language.units import ps, ns, us, ms, s from artiq.language.units import ps, ns, us, ms, s
from artiq.language.units import Hz, kHz, MHz, GHz from artiq.language.units import Hz, kHz, MHz, GHz

View File

@ -4,7 +4,7 @@ import os
from fractions import Fraction from fractions import Fraction
from artiq import * from artiq import *
from artiq.language.units import * from artiq.language.units import DimensionError
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices from artiq.sim import devices as sim_devices
@ -51,18 +51,19 @@ class _Misc(AutoContext):
self.input = 84 self.input = 84
self.inhomogeneous_units = [] self.inhomogeneous_units = []
self.al = [1, 2, 3, 4, 5] self.al = [1, 2, 3, 4, 5]
self.list_copy_in = [2*Hz, 10*MHz]
@kernel @kernel
def run(self): def run(self):
self.half_input = self.input//2 self.half_input = self.input//2
decimal_fraction = Fraction("1.2") self.decimal_fraction = Fraction("1.2")
self.decimal_fraction_n = int(decimal_fraction.numerator) self.inhomogeneous_units.append(1000*Hz)
self.decimal_fraction_d = int(decimal_fraction.denominator) self.inhomogeneous_units.append(10*s)
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s"))
self.acc = 0 self.acc = 0
for i in range(len(self.al)): for i in range(len(self.al)):
self.acc += self.al[i] self.acc += self.al[i]
self.list_copy_out = self.list_copy_in
self.unit_comp = [1*MHz for _ in range(3)]
@kernel @kernel
def dimension_error1(self): def dimension_error1(self):
@ -184,12 +185,11 @@ class ExecutionCase(unittest.TestCase):
uut = _Misc(core=coredev) uut = _Misc(core=coredev)
uut.run() uut.run()
self.assertEqual(uut.half_input, 42) self.assertEqual(uut.half_input, 42)
self.assertEqual(Fraction(uut.decimal_fraction_n, self.assertEqual(uut.decimal_fraction, Fraction("1.2"))
uut.decimal_fraction_d), self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s])
Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")])
self.assertEqual(uut.acc, sum(uut.al)) self.assertEqual(uut.acc, sum(uut.al))
self.assertEqual(uut.list_copy_in, uut.list_copy_out)
self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)])
with self.assertRaises(DimensionError): with self.assertRaises(DimensionError):
uut.dimension_error1() uut.dimension_error1()
with self.assertRaises(DimensionError): with self.assertRaises(DimensionError):

View File

@ -8,8 +8,15 @@ from artiq.transforms.tools import embeddable_func_names
def _add_units(f, unit_list): def _add_units(f, unit_list):
def wrapper(*args): def wrapper(*args):
new_args = [arg if unit is None else units.Quantity(arg, unit) new_args = []
for arg, unit in zip(args, unit_list)] for arg, unit in zip(args, unit_list):
if unit is None:
new_args.append(arg)
else:
if isinstance(arg, list):
new_args.append([units.Quantity(x, unit) for x in arg])
else:
new_args.append(units.Quantity(arg, unit))
return f(*new_args) return f(*new_args)
return wrapper return wrapper
@ -80,6 +87,20 @@ class _UnitsLowerer(ast.NodeTransformer):
else: else:
return node return node
def visit_List(self, node):
self.generic_visit(node)
us = [getattr(elt, "unit", None) for elt in node.elts]
if not all(u == us[0] for u in us[1:]):
raise units.DimensionError
node.unit = us[0]
return node
def visit_ListComp(self, node):
self.generic_visit(node)
if hasattr(node.elt, "unit"):
node.unit = node.elt.unit
return node
def visit_Call(self, node): def visit_Call(self, node):
self.generic_visit(node) self.generic_visit(node)
if node.func.id == "Quantity": if node.func.id == "Quantity":