forked from M-Labs/artiq
transforms/lower_units: fix bugs and add unit test
This commit is contained in:
parent
8d59f843fb
commit
ab88c6d0b8
|
@ -1,4 +1,6 @@
|
|||
import ast
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
|
||||
from artiq.language import units
|
||||
|
||||
|
@ -14,6 +16,7 @@ def _add_units(f, unit_list):
|
|||
class _UnitsLowerer(ast.NodeTransformer):
|
||||
def __init__(self, rpc_map):
|
||||
self.rpc_map = rpc_map
|
||||
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
|
||||
self.variable_units = dict()
|
||||
|
||||
def visit_Name(self, node):
|
||||
|
@ -66,9 +69,9 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
elif node.func.id == "now":
|
||||
node.unit = "s"
|
||||
elif node.func.id == "syscall" and node.args[0].s == "rpc":
|
||||
unit_list = [getattr(arg, "unit", None) for arg in node.args]
|
||||
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
|
||||
rpc_n = node.args[1].n
|
||||
self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list)
|
||||
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
|
||||
return node
|
||||
|
||||
def _update_target(self, target, unit):
|
||||
|
@ -105,4 +108,8 @@ class _UnitsLowerer(ast.NodeTransformer):
|
|||
|
||||
|
||||
def lower_units(func_def, rpc_map):
|
||||
_UnitsLowerer(rpc_map).visit(func_def)
|
||||
ul = _UnitsLowerer(rpc_map)
|
||||
ul.visit(func_def)
|
||||
original_map = copy(rpc_map)
|
||||
for (original_rpcn, unit_list), new_rpcn in ul.rpc_remap.items():
|
||||
rpc_map[new_rpcn] = _add_units(original_map[original_rpcn], unit_list)
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
from fractions import Fraction
|
||||
|
||||
from artiq import *
|
||||
from artiq.language.units import Quantity
|
||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||
from artiq.sim import devices as sim_devices
|
||||
|
||||
|
@ -44,6 +45,7 @@ class _Primes(AutoContext):
|
|||
class _Misc(AutoContext):
|
||||
def build(self):
|
||||
self.input = 84
|
||||
self.inhomogeneous_units = []
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
|
@ -51,6 +53,8 @@ class _Misc(AutoContext):
|
|||
decimal_fraction = Fraction("1.2")
|
||||
self.decimal_fraction_n = int(decimal_fraction.numerator)
|
||||
self.decimal_fraction_d = int(decimal_fraction.denominator)
|
||||
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
|
||||
self.inhomogeneous_units.append(Quantity(10, "s"))
|
||||
|
||||
|
||||
class _PulseLogger(AutoContext):
|
||||
|
@ -157,6 +161,8 @@ class ExecutionCase(unittest.TestCase):
|
|||
self.assertEqual(Fraction(uut.decimal_fraction_n,
|
||||
uut.decimal_fraction_d),
|
||||
Fraction("1.2"))
|
||||
self.assertEqual(uut.inhomogeneous_units, [
|
||||
Quantity(1000, "Hz"), Quantity(10, "s")])
|
||||
|
||||
def test_pulses(self):
|
||||
l_device, l_host = [], []
|
||||
|
|
Loading…
Reference in New Issue