From e22301ea0589b14648adc566d5122780c2eff91e Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 6 Oct 2014 23:28:56 +0800 Subject: [PATCH] transforms: track units, now() returns seconds, implement time_to_cycles and cycles_to_time --- artiq/devices/core.py | 6 +- artiq/devices/dds_core.py | 4 +- artiq/devices/rtio_core.py | 14 ++-- artiq/transforms/lower_time.py | 53 +++++++++++--- artiq/transforms/lower_units.py | 124 ++++++++++++++++++++++++-------- 5 files changed, 153 insertions(+), 48 deletions(-) diff --git a/artiq/devices/core.py b/artiq/devices/core.py index 8635f75a6..0cdd1bd43 100644 --- a/artiq/devices/core.py +++ b/artiq/devices/core.py @@ -48,7 +48,7 @@ class Core: self, k_function, k_args, k_kwargs) _debug_unparse("inline", func_def) - lower_units(func_def, self.runtime_env.ref_period) + lower_units(func_def, rpc_map) _debug_unparse("lower_units", func_def) fold_constants(func_def) @@ -60,7 +60,9 @@ class Core: interleave(func_def) _debug_unparse("interleave", func_def) - lower_time(func_def, getattr(self.runtime_env, "initial_time", 0)) + lower_time(func_def, + getattr(self.runtime_env, "initial_time", 0), + self.runtime_env.ref_period) _debug_unparse("lower_time", func_def) fold_constants(func_def) diff --git a/artiq/devices/dds_core.py b/artiq/devices/dds_core.py index c0dbd1fcf..d320b7510 100644 --- a/artiq/devices/dds_core.py +++ b/artiq/devices/dds_core.py @@ -34,12 +34,12 @@ class DDS(AutoContext): """ if self.previous_frequency != frequency: - if self.sw.previous_timestamp != now(): + if self.sw.previous_timestamp != time_to_cycles(now()): self.sw.sync() if self.sw.previous_value: # Channel is already on. # Precise timing of frequency change is required. - fud_time = now() + fud_time = time_to_cycles(now()) else: # Channel is off. # Use soft timing on FUD to prevent conflicts when diff --git a/artiq/devices/rtio_core.py b/artiq/devices/rtio_core.py index 88143a919..4c95cc4fa 100644 --- a/artiq/devices/rtio_core.py +++ b/artiq/devices/rtio_core.py @@ -6,7 +6,7 @@ class _RTIOBase(AutoContext): parameters = "channel" def build(self): - self.previous_timestamp = int64(0) + self.previous_timestamp = int64(0) # in RTIO cycles self.previous_value = 0 kernel_attr = "previous_timestamp previous_value" @@ -17,14 +17,16 @@ class _RTIOBase(AutoContext): @kernel def _set_value(self, value): - if now() < self.previous_timestamp: + if time_to_cycles(now()) < self.previous_timestamp: raise RTIOSequenceError if self.previous_value != value: - if self.previous_timestamp == now(): - syscall("rtio_replace", now(), self.channel, value) + if self.previous_timestamp == time_to_cycles(now()): + syscall("rtio_replace", time_to_cycles(now()), + self.channel, value) else: - syscall("rtio_set", now(), self.channel, value) - self.previous_timestamp = now() + syscall("rtio_set", time_to_cycles(now()), + self.channel, value) + self.previous_timestamp = time_to_cycles(now()) self.previous_value = value diff --git a/artiq/transforms/lower_time.py b/artiq/transforms/lower_time.py index 76d2eab58..c5ee3159c 100644 --- a/artiq/transforms/lower_time.py +++ b/artiq/transforms/lower_time.py @@ -4,21 +4,52 @@ from artiq.transforms.tools import value_to_ast from artiq.language.core import int64 -def _insert_int64(node): +def _time_to_cycles(ref_period, node): + divided = ast.copy_location( + ast.BinOp(left=node, + op=ast.Div(), + right=value_to_ast(ref_period)), + node) return ast.copy_location( ast.Call(func=ast.Name("int64", ast.Load()), - args=[node], + args=[divided], keywords=[], starargs=[], kwargs=[]), + divided) + + +def _cycles_to_time(ref_period, node): + return ast.copy_location( + ast.BinOp(left=node, + op=ast.Mult(), + right=value_to_ast(ref_period)), node) class _TimeLowerer(ast.NodeTransformer): + def __init__(self, ref_period): + self.ref_period = ref_period + def visit_Call(self, node): - if isinstance(node.func, ast.Name) and node.func.id == "now": + # optimize time_to_cycles(now()) -> now + if (isinstance(node.func, ast.Name) + and node.func.id == "time_to_cycles" + and isinstance(node.args[0], ast.Call) + and isinstance(node.args[0].func, ast.Name) + and node.args[0].func.id == "now"): return ast.copy_location(ast.Name("now", ast.Load()), node) - else: - self.generic_visit(node) - return node + + self.generic_visit(node) + if isinstance(node.func, ast.Name): + funcname = node.func.id + if funcname == "now": + return _cycles_to_time( + self.ref_period, + ast.copy_location(ast.Name("now", ast.Load()), node)) + elif funcname == "time_to_cycles": + return _time_to_cycles(self.ref_period, node) + elif funcname == "cycles_to_time": + return _cycles_to_time(self.ref_period, node) + return node def visit_Expr(self, node): self.generic_visit(node) @@ -29,12 +60,14 @@ class _TimeLowerer(ast.NodeTransformer): return ast.copy_location( ast.AugAssign(target=ast.Name("now", ast.Store()), op=ast.Add(), - value=_insert_int64(node.value.args[0])), + value=_time_to_cycles(self.ref_period, + node.value.args[0])), node) elif funcname == "at": return ast.copy_location( ast.Assign(targets=[ast.Name("now", ast.Store())], - value=_insert_int64(node.value.args[0])), + value=_time_to_cycles(self.ref_period, + node.value.args[0])), node) else: return node @@ -42,8 +75,8 @@ class _TimeLowerer(ast.NodeTransformer): return node -def lower_time(func_def, initial_time): - _TimeLowerer().visit(func_def) +def lower_time(func_def, initial_time, ref_period): + _TimeLowerer(ref_period).visit(func_def) func_def.body.insert(0, ast.copy_location( ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(int64(initial_time))), diff --git a/artiq/transforms/lower_units.py b/artiq/transforms/lower_units.py index c5914c029..3428ea11e 100644 --- a/artiq/transforms/lower_units.py +++ b/artiq/transforms/lower_units.py @@ -1,40 +1,108 @@ import ast -from artiq.transforms.tools import value_to_ast +from artiq.language import units -# TODO: -# * track variable and expression dimensions -# * raise exception on dimension errors in expressions -# * modify RPC map to reintroduce units -# * handle core time conversion outside of delay/at, -# e.g. foo = now() + 1*us [...] at(foo) +def _add_units(f, unit_list): + def wrapper(*args): + new_args = [arg if unit is None else units.Quantity(arg, unit) + for arg, unit in zip(args, unit_list)] + return f(*new_args) + return wrapper + class _UnitsLowerer(ast.NodeTransformer): - def __init__(self, ref_period): - self.ref_period = ref_period - self.in_core_time = False + def __init__(self, rpc_map): + self.rpc_map = rpc_map + self.variable_units = dict() + + def visit_Name(self, node): + try: + unit = self.variable_units[node.id] + except KeyError: + pass + else: + if unit is not None: + node.unit = unit + return node + + def visit_UnaryOp(self, node): + self.generic_visit(node) + if hasattr(node.operand, "unit"): + node.unit = node.operand.unit + return node + + def visit_BinOp(self, node): + self.generic_visit(node) + op = type(node.op) + left_unit = getattr(node.left, "unit", None) + right_unit = getattr(node.right, "unit", None) + if op in (ast.Add, ast.Sub, ast.Mod): + unit = units.addsub_dimension(left_unit, right_unit) + elif op == ast.Mult: + unit = units.mul_dimension(left_unit, right_unit) + elif op in (ast.Div, ast.FloorDiv): + unit = units.div_dimension(left_unit, right_unit) + else: + unit = None + if unit is not None: + node.unit = unit + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + if node.attr == "amount" and hasattr(node.value, "unit"): + del node.value.unit + return node.value + else: + return node def visit_Call(self, node): - fn = node.func.id - if fn in ("delay", "at"): - old_in_core_time = self.in_core_time - self.in_core_time = True - self.generic_visit(node) - self.in_core_time = old_in_core_time - elif fn == "Quantity": - if self.in_core_time: - node = ast.copy_location( - ast.BinOp(left=node.args[0], - op=ast.Div(), - right=value_to_ast(self.ref_period)), - node) + self.generic_visit(node) + if node.func.id == "Quantity": + amount, unit = node.args + amount.unit = unit.s + return amount + elif node.func.id == "now": + node.unit = "s" + elif node.func.id == "syscall" and node.args[0].s == "rpc": + unit_list = [getattr(arg, "unit", None) for arg in node.args] + rpc_n = node.args[1].n + self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list) + return node + + def _update_target(self, target, unit): + if isinstance(target, ast.Name): + if target.id in self.variable_units: + if self.variable_units[target.id] != unit: + raise TypeError( + "Inconsistent units for variable '{}': '{}' and '{}'" + .format(target.id, + self.variable_units[target.id], + unit)) else: - node = node.args[0] - else: - self.generic_visit(node) + self.variable_units[target.id] = unit + + def visit_Assign(self, node): + node.value = self.visit(node.value) + unit = getattr(node.value, "unit", None) + for target in node.targets: + self._update_target(target, unit) + return node + + def visit_AugAssign(self, node): + value = self.visit_BinOp(ast.BinOp( + op=node.op, left=node.target, right=node.value)) + unit = getattr(value, "unit", None) + self._update_target(node.target, unit) + return node + + # Only dimensionless iterators are supported + def visit_For(self, node): + self.generic_visit(node) + self._update_target(node.target, None) return node -def lower_units(func_def, ref_period): - _UnitsLowerer(ref_period).visit(func_def) +def lower_units(func_def, rpc_map): + _UnitsLowerer(rpc_map).visit(func_def)