transforms: track units, now() returns seconds, implement time_to_cycles and cycles_to_time

This commit is contained in:
Sebastien Bourdeauducq 2014-10-06 23:28:56 +08:00
parent 1a64e92e75
commit e22301ea05
5 changed files with 153 additions and 48 deletions

View File

@ -48,7 +48,7 @@ class Core:
self, k_function, k_args, k_kwargs) self, k_function, k_args, k_kwargs)
_debug_unparse("inline", func_def) _debug_unparse("inline", func_def)
lower_units(func_def, self.runtime_env.ref_period) lower_units(func_def, rpc_map)
_debug_unparse("lower_units", func_def) _debug_unparse("lower_units", func_def)
fold_constants(func_def) fold_constants(func_def)
@ -60,7 +60,9 @@ class Core:
interleave(func_def) interleave(func_def)
_debug_unparse("interleave", func_def) _debug_unparse("interleave", func_def)
lower_time(func_def, getattr(self.runtime_env, "initial_time", 0)) lower_time(func_def,
getattr(self.runtime_env, "initial_time", 0),
self.runtime_env.ref_period)
_debug_unparse("lower_time", func_def) _debug_unparse("lower_time", func_def)
fold_constants(func_def) fold_constants(func_def)

View File

@ -34,12 +34,12 @@ class DDS(AutoContext):
""" """
if self.previous_frequency != frequency: if self.previous_frequency != frequency:
if self.sw.previous_timestamp != now(): if self.sw.previous_timestamp != time_to_cycles(now()):
self.sw.sync() self.sw.sync()
if self.sw.previous_value: if self.sw.previous_value:
# Channel is already on. # Channel is already on.
# Precise timing of frequency change is required. # Precise timing of frequency change is required.
fud_time = now() fud_time = time_to_cycles(now())
else: else:
# Channel is off. # Channel is off.
# Use soft timing on FUD to prevent conflicts when # Use soft timing on FUD to prevent conflicts when

View File

@ -6,7 +6,7 @@ class _RTIOBase(AutoContext):
parameters = "channel" parameters = "channel"
def build(self): def build(self):
self.previous_timestamp = int64(0) self.previous_timestamp = int64(0) # in RTIO cycles
self.previous_value = 0 self.previous_value = 0
kernel_attr = "previous_timestamp previous_value" kernel_attr = "previous_timestamp previous_value"
@ -17,14 +17,16 @@ class _RTIOBase(AutoContext):
@kernel @kernel
def _set_value(self, value): def _set_value(self, value):
if now() < self.previous_timestamp: if time_to_cycles(now()) < self.previous_timestamp:
raise RTIOSequenceError raise RTIOSequenceError
if self.previous_value != value: if self.previous_value != value:
if self.previous_timestamp == now(): if self.previous_timestamp == time_to_cycles(now()):
syscall("rtio_replace", now(), self.channel, value) syscall("rtio_replace", time_to_cycles(now()),
self.channel, value)
else: else:
syscall("rtio_set", now(), self.channel, value) syscall("rtio_set", time_to_cycles(now()),
self.previous_timestamp = now() self.channel, value)
self.previous_timestamp = time_to_cycles(now())
self.previous_value = value self.previous_value = value

View File

@ -4,21 +4,52 @@ from artiq.transforms.tools import value_to_ast
from artiq.language.core import int64 from artiq.language.core import int64
def _insert_int64(node): def _time_to_cycles(ref_period, node):
divided = ast.copy_location(
ast.BinOp(left=node,
op=ast.Div(),
right=value_to_ast(ref_period)),
node)
return ast.copy_location( return ast.copy_location(
ast.Call(func=ast.Name("int64", ast.Load()), ast.Call(func=ast.Name("int64", ast.Load()),
args=[node], args=[divided],
keywords=[], starargs=[], kwargs=[]), keywords=[], starargs=[], kwargs=[]),
divided)
def _cycles_to_time(ref_period, node):
return ast.copy_location(
ast.BinOp(left=node,
op=ast.Mult(),
right=value_to_ast(ref_period)),
node) node)
class _TimeLowerer(ast.NodeTransformer): class _TimeLowerer(ast.NodeTransformer):
def __init__(self, ref_period):
self.ref_period = ref_period
def visit_Call(self, node): def visit_Call(self, node):
if isinstance(node.func, ast.Name) and node.func.id == "now": # optimize time_to_cycles(now()) -> now
if (isinstance(node.func, ast.Name)
and node.func.id == "time_to_cycles"
and isinstance(node.args[0], ast.Call)
and isinstance(node.args[0].func, ast.Name)
and node.args[0].func.id == "now"):
return ast.copy_location(ast.Name("now", ast.Load()), node) return ast.copy_location(ast.Name("now", ast.Load()), node)
else:
self.generic_visit(node) self.generic_visit(node)
return node if isinstance(node.func, ast.Name):
funcname = node.func.id
if funcname == "now":
return _cycles_to_time(
self.ref_period,
ast.copy_location(ast.Name("now", ast.Load()), node))
elif funcname == "time_to_cycles":
return _time_to_cycles(self.ref_period, node)
elif funcname == "cycles_to_time":
return _cycles_to_time(self.ref_period, node)
return node
def visit_Expr(self, node): def visit_Expr(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -29,12 +60,14 @@ class _TimeLowerer(ast.NodeTransformer):
return ast.copy_location( return ast.copy_location(
ast.AugAssign(target=ast.Name("now", ast.Store()), ast.AugAssign(target=ast.Name("now", ast.Store()),
op=ast.Add(), op=ast.Add(),
value=_insert_int64(node.value.args[0])), value=_time_to_cycles(self.ref_period,
node.value.args[0])),
node) node)
elif funcname == "at": elif funcname == "at":
return ast.copy_location( return ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())], ast.Assign(targets=[ast.Name("now", ast.Store())],
value=_insert_int64(node.value.args[0])), value=_time_to_cycles(self.ref_period,
node.value.args[0])),
node) node)
else: else:
return node return node
@ -42,8 +75,8 @@ class _TimeLowerer(ast.NodeTransformer):
return node return node
def lower_time(func_def, initial_time): def lower_time(func_def, initial_time, ref_period):
_TimeLowerer().visit(func_def) _TimeLowerer(ref_period).visit(func_def)
func_def.body.insert(0, ast.copy_location( func_def.body.insert(0, ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())], ast.Assign(targets=[ast.Name("now", ast.Store())],
value=value_to_ast(int64(initial_time))), value=value_to_ast(int64(initial_time))),

View File

@ -1,40 +1,108 @@
import ast import ast
from artiq.transforms.tools import value_to_ast from artiq.language import units
# TODO: def _add_units(f, unit_list):
# * track variable and expression dimensions def wrapper(*args):
# * raise exception on dimension errors in expressions new_args = [arg if unit is None else units.Quantity(arg, unit)
# * modify RPC map to reintroduce units for arg, unit in zip(args, unit_list)]
# * handle core time conversion outside of delay/at, return f(*new_args)
# e.g. foo = now() + 1*us [...] at(foo) return wrapper
class _UnitsLowerer(ast.NodeTransformer): class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, ref_period): def __init__(self, rpc_map):
self.ref_period = ref_period self.rpc_map = rpc_map
self.in_core_time = False self.variable_units = dict()
def visit_Name(self, node):
try:
unit = self.variable_units[node.id]
except KeyError:
pass
else:
if unit is not None:
node.unit = unit
return node
def visit_UnaryOp(self, node):
self.generic_visit(node)
if hasattr(node.operand, "unit"):
node.unit = node.operand.unit
return node
def visit_BinOp(self, node):
self.generic_visit(node)
op = type(node.op)
left_unit = getattr(node.left, "unit", None)
right_unit = getattr(node.right, "unit", None)
if op in (ast.Add, ast.Sub, ast.Mod):
unit = units.addsub_dimension(left_unit, right_unit)
elif op == ast.Mult:
unit = units.mul_dimension(left_unit, right_unit)
elif op in (ast.Div, ast.FloorDiv):
unit = units.div_dimension(left_unit, right_unit)
else:
unit = None
if unit is not None:
node.unit = unit
return node
def visit_Attribute(self, node):
self.generic_visit(node)
if node.attr == "amount" and hasattr(node.value, "unit"):
del node.value.unit
return node.value
else:
return node
def visit_Call(self, node): def visit_Call(self, node):
fn = node.func.id self.generic_visit(node)
if fn in ("delay", "at"): if node.func.id == "Quantity":
old_in_core_time = self.in_core_time amount, unit = node.args
self.in_core_time = True amount.unit = unit.s
self.generic_visit(node) return amount
self.in_core_time = old_in_core_time elif node.func.id == "now":
elif fn == "Quantity": node.unit = "s"
if self.in_core_time: elif node.func.id == "syscall" and node.args[0].s == "rpc":
node = ast.copy_location( unit_list = [getattr(arg, "unit", None) for arg in node.args]
ast.BinOp(left=node.args[0], rpc_n = node.args[1].n
op=ast.Div(), self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list)
right=value_to_ast(self.ref_period)), return node
node)
def _update_target(self, target, unit):
if isinstance(target, ast.Name):
if target.id in self.variable_units:
if self.variable_units[target.id] != unit:
raise TypeError(
"Inconsistent units for variable '{}': '{}' and '{}'"
.format(target.id,
self.variable_units[target.id],
unit))
else: else:
node = node.args[0] self.variable_units[target.id] = unit
else:
self.generic_visit(node) def visit_Assign(self, node):
node.value = self.visit(node.value)
unit = getattr(node.value, "unit", None)
for target in node.targets:
self._update_target(target, unit)
return node
def visit_AugAssign(self, node):
value = self.visit_BinOp(ast.BinOp(
op=node.op, left=node.target, right=node.value))
unit = getattr(value, "unit", None)
self._update_target(node.target, unit)
return node
# Only dimensionless iterators are supported
def visit_For(self, node):
self.generic_visit(node)
self._update_target(node.target, None)
return node return node
def lower_units(func_def, ref_period): def lower_units(func_def, rpc_map):
_UnitsLowerer(ref_period).visit(func_def) _UnitsLowerer(rpc_map).visit(func_def)