transforms/lower_units: fix bugs and add unit test

This commit is contained in:
Sebastien Bourdeauducq 2014-11-21 18:08:14 -08:00
parent 8d59f843fb
commit ab88c6d0b8
2 changed files with 16 additions and 3 deletions

View File

@ -1,4 +1,6 @@
import ast import ast
from collections import defaultdict
from copy import copy
from artiq.language import units from artiq.language import units
@ -14,6 +16,7 @@ def _add_units(f, unit_list):
class _UnitsLowerer(ast.NodeTransformer): class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, rpc_map): def __init__(self, rpc_map):
self.rpc_map = rpc_map self.rpc_map = rpc_map
self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))
self.variable_units = dict() self.variable_units = dict()
def visit_Name(self, node): def visit_Name(self, node):
@ -66,9 +69,9 @@ class _UnitsLowerer(ast.NodeTransformer):
elif node.func.id == "now": elif node.func.id == "now":
node.unit = "s" node.unit = "s"
elif node.func.id == "syscall" and node.args[0].s == "rpc": elif node.func.id == "syscall" and node.args[0].s == "rpc":
unit_list = [getattr(arg, "unit", None) for arg in node.args] unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
rpc_n = node.args[1].n rpc_n = node.args[1].n
self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list) node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
return node return node
def _update_target(self, target, unit): def _update_target(self, target, unit):
@ -105,4 +108,8 @@ class _UnitsLowerer(ast.NodeTransformer):
def lower_units(func_def, rpc_map): def lower_units(func_def, rpc_map):
_UnitsLowerer(rpc_map).visit(func_def) ul = _UnitsLowerer(rpc_map)
ul.visit(func_def)
original_map = copy(rpc_map)
for (original_rpcn, unit_list), new_rpcn in ul.rpc_remap.items():
rpc_map[new_rpcn] = _add_units(original_map[original_rpcn], unit_list)

View File

@ -4,6 +4,7 @@ import os
from fractions import Fraction from fractions import Fraction
from artiq import * from artiq import *
from artiq.language.units import Quantity
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices from artiq.sim import devices as sim_devices
@ -44,6 +45,7 @@ class _Primes(AutoContext):
class _Misc(AutoContext): class _Misc(AutoContext):
def build(self): def build(self):
self.input = 84 self.input = 84
self.inhomogeneous_units = []
@kernel @kernel
def run(self): def run(self):
@ -51,6 +53,8 @@ class _Misc(AutoContext):
decimal_fraction = Fraction("1.2") decimal_fraction = Fraction("1.2")
self.decimal_fraction_n = int(decimal_fraction.numerator) self.decimal_fraction_n = int(decimal_fraction.numerator)
self.decimal_fraction_d = int(decimal_fraction.denominator) self.decimal_fraction_d = int(decimal_fraction.denominator)
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s"))
class _PulseLogger(AutoContext): class _PulseLogger(AutoContext):
@ -157,6 +161,8 @@ class ExecutionCase(unittest.TestCase):
self.assertEqual(Fraction(uut.decimal_fraction_n, self.assertEqual(Fraction(uut.decimal_fraction_n,
uut.decimal_fraction_d), uut.decimal_fraction_d),
Fraction("1.2")) Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")])
def test_pulses(self): def test_pulses(self):
l_device, l_host = [], [] l_device, l_host = [], []