diff --git a/artiq/transforms/lower_units.py b/artiq/transforms/lower_units.py index 917a7a0cf..13c357cfa 100644 --- a/artiq/transforms/lower_units.py +++ b/artiq/transforms/lower_units.py @@ -1,4 +1,6 @@ import ast +from collections import defaultdict +from copy import copy from artiq.language import units @@ -14,6 +16,7 @@ def _add_units(f, unit_list): class _UnitsLowerer(ast.NodeTransformer): def __init__(self, rpc_map): self.rpc_map = rpc_map + self.rpc_remap = defaultdict(lambda: len(self.rpc_remap)) self.variable_units = dict() def visit_Name(self, node): @@ -66,9 +69,9 @@ class _UnitsLowerer(ast.NodeTransformer): elif node.func.id == "now": node.unit = "s" elif node.func.id == "syscall" and node.args[0].s == "rpc": - unit_list = [getattr(arg, "unit", None) for arg in node.args] + unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:]) rpc_n = node.args[1].n - self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list) + node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))] return node def _update_target(self, target, unit): @@ -105,4 +108,8 @@ class _UnitsLowerer(ast.NodeTransformer): def lower_units(func_def, rpc_map): - _UnitsLowerer(rpc_map).visit(func_def) + ul = _UnitsLowerer(rpc_map) + ul.visit(func_def) + original_map = copy(rpc_map) + for (original_rpcn, unit_list), new_rpcn in ul.rpc_remap.items(): + rpc_map[new_rpcn] = _add_units(original_map[original_rpcn], unit_list) diff --git a/test/full_stack.py b/test/full_stack.py index d8697b0b9..a68aa499c 100644 --- a/test/full_stack.py +++ b/test/full_stack.py @@ -4,6 +4,7 @@ import os from fractions import Fraction from artiq import * +from artiq.language.units import Quantity from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.sim import devices as sim_devices @@ -44,6 +45,7 @@ class _Primes(AutoContext): class _Misc(AutoContext): def build(self): self.input = 84 + self.inhomogeneous_units = [] @kernel def run(self): @@ -51,6 +53,8 @@ class _Misc(AutoContext): decimal_fraction = Fraction("1.2") self.decimal_fraction_n = int(decimal_fraction.numerator) self.decimal_fraction_d = int(decimal_fraction.denominator) + self.inhomogeneous_units.append(Quantity(1000, "Hz")) + self.inhomogeneous_units.append(Quantity(10, "s")) class _PulseLogger(AutoContext): @@ -157,6 +161,8 @@ class ExecutionCase(unittest.TestCase): self.assertEqual(Fraction(uut.decimal_fraction_n, uut.decimal_fraction_d), Fraction("1.2")) + self.assertEqual(uut.inhomogeneous_units, [ + Quantity(1000, "Hz"), Quantity(10, "s")]) def test_pulses(self): l_device, l_host = [], []