forked from M-Labs/artiq
transforms: track units, now() returns seconds, implement time_to_cycles and cycles_to_time
This commit is contained in:
parent
1a64e92e75
commit
e22301ea05
|
@ -48,7 +48,7 @@ class Core:
|
|||
self, k_function, k_args, k_kwargs)
|
||||
_debug_unparse("inline", func_def)
|
||||
|
||||
lower_units(func_def, self.runtime_env.ref_period)
|
||||
lower_units(func_def, rpc_map)
|
||||
_debug_unparse("lower_units", func_def)
|
||||
|
||||
fold_constants(func_def)
|
||||
|
@ -60,7 +60,9 @@ class Core:
|
|||
interleave(func_def)
|
||||
_debug_unparse("interleave", func_def)
|
||||
|
||||
lower_time(func_def, getattr(self.runtime_env, "initial_time", 0))
|
||||
lower_time(func_def,
|
||||
getattr(self.runtime_env, "initial_time", 0),
|
||||
self.runtime_env.ref_period)
|
||||
_debug_unparse("lower_time", func_def)
|
||||
|
||||
fold_constants(func_def)
|
||||
|
|
|
@ -34,12 +34,12 @@ class DDS(AutoContext):
|
|||
|
||||
"""
|
||||
if self.previous_frequency != frequency:
|
||||
if self.sw.previous_timestamp != now():
|
||||
if self.sw.previous_timestamp != time_to_cycles(now()):
|
||||
self.sw.sync()
|
||||
if self.sw.previous_value:
|
||||
# Channel is already on.
|
||||
# Precise timing of frequency change is required.
|
||||
fud_time = now()
|
||||
fud_time = time_to_cycles(now())
|
||||
else:
|
||||
# Channel is off.
|
||||
# Use soft timing on FUD to prevent conflicts when
|
||||
|
|
|
@ -6,7 +6,7 @@ class _RTIOBase(AutoContext):
|
|||
parameters = "channel"
|
||||
|
||||
def build(self):
|
||||
self.previous_timestamp = int64(0)
|
||||
self.previous_timestamp = int64(0) # in RTIO cycles
|
||||
self.previous_value = 0
|
||||
|
||||
kernel_attr = "previous_timestamp previous_value"
|
||||
|
@ -17,14 +17,16 @@ class _RTIOBase(AutoContext):
|
|||
|
||||
@kernel
|
||||
def _set_value(self, value):
|
||||
if now() < self.previous_timestamp:
|
||||
if time_to_cycles(now()) < self.previous_timestamp:
|
||||
raise RTIOSequenceError
|
||||
if self.previous_value != value:
|
||||
if self.previous_timestamp == now():
|
||||
syscall("rtio_replace", now(), self.channel, value)
|
||||
if self.previous_timestamp == time_to_cycles(now()):
|
||||
syscall("rtio_replace", time_to_cycles(now()),
|
||||
self.channel, value)
|
||||
else:
|
||||
syscall("rtio_set", now(), self.channel, value)
|
||||
self.previous_timestamp = now()
|
||||
syscall("rtio_set", time_to_cycles(now()),
|
||||
self.channel, value)
|
||||
self.previous_timestamp = time_to_cycles(now())
|
||||
self.previous_value = value
|
||||
|
||||
|
||||
|
|
|
@ -4,21 +4,52 @@ from artiq.transforms.tools import value_to_ast
|
|||
from artiq.language.core import int64
|
||||
|
||||
|
||||
def _insert_int64(node):
|
||||
def _time_to_cycles(ref_period, node):
|
||||
divided = ast.copy_location(
|
||||
ast.BinOp(left=node,
|
||||
op=ast.Div(),
|
||||
right=value_to_ast(ref_period)),
|
||||
node)
|
||||
return ast.copy_location(
|
||||
ast.Call(func=ast.Name("int64", ast.Load()),
|
||||
args=[node],
|
||||
args=[divided],
|
||||
keywords=[], starargs=[], kwargs=[]),
|
||||
divided)
|
||||
|
||||
|
||||
def _cycles_to_time(ref_period, node):
|
||||
return ast.copy_location(
|
||||
ast.BinOp(left=node,
|
||||
op=ast.Mult(),
|
||||
right=value_to_ast(ref_period)),
|
||||
node)
|
||||
|
||||
|
||||
class _TimeLowerer(ast.NodeTransformer):
|
||||
def __init__(self, ref_period):
|
||||
self.ref_period = ref_period
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Name) and node.func.id == "now":
|
||||
# optimize time_to_cycles(now()) -> now
|
||||
if (isinstance(node.func, ast.Name)
|
||||
and node.func.id == "time_to_cycles"
|
||||
and isinstance(node.args[0], ast.Call)
|
||||
and isinstance(node.args[0].func, ast.Name)
|
||||
and node.args[0].func.id == "now"):
|
||||
return ast.copy_location(ast.Name("now", ast.Load()), node)
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.func, ast.Name):
|
||||
funcname = node.func.id
|
||||
if funcname == "now":
|
||||
return _cycles_to_time(
|
||||
self.ref_period,
|
||||
ast.copy_location(ast.Name("now", ast.Load()), node))
|
||||
elif funcname == "time_to_cycles":
|
||||
return _time_to_cycles(self.ref_period, node)
|
||||
elif funcname == "cycles_to_time":
|
||||
return _cycles_to_time(self.ref_period, node)
|
||||
return node
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.generic_visit(node)
|
||||
|
@ -29,12 +60,14 @@ class _TimeLowerer(ast.NodeTransformer):
|
|||
return ast.copy_location(
|
||||
ast.AugAssign(target=ast.Name("now", ast.Store()),
|
||||
op=ast.Add(),
|
||||
value=_insert_int64(node.value.args[0])),
|
||||
value=_time_to_cycles(self.ref_period,
|
||||
node.value.args[0])),
|
||||
node)
|
||||
elif funcname == "at":
|
||||
return ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name("now", ast.Store())],
|
||||
value=_insert_int64(node.value.args[0])),
|
||||
value=_time_to_cycles(self.ref_period,
|
||||
node.value.args[0])),
|
||||
node)
|
||||
else:
|
||||
return node
|
||||
|
@ -42,8 +75,8 @@ class _TimeLowerer(ast.NodeTransformer):
|
|||
return node
|
||||
|
||||
|
||||
def lower_time(func_def, initial_time):
|
||||
_TimeLowerer().visit(func_def)
|
||||
def lower_time(func_def, initial_time, ref_period):
|
||||
_TimeLowerer(ref_period).visit(func_def)
|
||||
func_def.body.insert(0, ast.copy_location(
|
||||
ast.Assign(targets=[ast.Name("now", ast.Store())],
|
||||
value=value_to_ast(int64(initial_time))),
|
||||
|
|
|
@ -1,40 +1,108 @@
|
|||
import ast
|
||||
|
||||
from artiq.transforms.tools import value_to_ast
|
||||
from artiq.language import units
|
||||
|
||||
|
||||
# TODO:
|
||||
# * track variable and expression dimensions
|
||||
# * raise exception on dimension errors in expressions
|
||||
# * modify RPC map to reintroduce units
|
||||
# * handle core time conversion outside of delay/at,
|
||||
# e.g. foo = now() + 1*us [...] at(foo)
|
||||
def _add_units(f, unit_list):
|
||||
def wrapper(*args):
|
||||
new_args = [arg if unit is None else units.Quantity(arg, unit)
|
||||
for arg, unit in zip(args, unit_list)]
|
||||
return f(*new_args)
|
||||
return wrapper
|
||||
|
||||
|
||||
class _UnitsLowerer(ast.NodeTransformer):
|
||||
def __init__(self, ref_period):
|
||||
self.ref_period = ref_period
|
||||
self.in_core_time = False
|
||||
def __init__(self, rpc_map):
|
||||
self.rpc_map = rpc_map
|
||||
self.variable_units = dict()
|
||||
|
||||
def visit_Name(self, node):
|
||||
try:
|
||||
unit = self.variable_units[node.id]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if unit is not None:
|
||||
node.unit = unit
|
||||
return node
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
self.generic_visit(node)
|
||||
if hasattr(node.operand, "unit"):
|
||||
node.unit = node.operand.unit
|
||||
return node
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
self.generic_visit(node)
|
||||
op = type(node.op)
|
||||
left_unit = getattr(node.left, "unit", None)
|
||||
right_unit = getattr(node.right, "unit", None)
|
||||
if op in (ast.Add, ast.Sub, ast.Mod):
|
||||
unit = units.addsub_dimension(left_unit, right_unit)
|
||||
elif op == ast.Mult:
|
||||
unit = units.mul_dimension(left_unit, right_unit)
|
||||
elif op in (ast.Div, ast.FloorDiv):
|
||||
unit = units.div_dimension(left_unit, right_unit)
|
||||
else:
|
||||
unit = None
|
||||
if unit is not None:
|
||||
node.unit = unit
|
||||
return node
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.generic_visit(node)
|
||||
if node.attr == "amount" and hasattr(node.value, "unit"):
|
||||
del node.value.unit
|
||||
return node.value
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = node.func.id
|
||||
if fn in ("delay", "at"):
|
||||
old_in_core_time = self.in_core_time
|
||||
self.in_core_time = True
|
||||
self.generic_visit(node)
|
||||
self.in_core_time = old_in_core_time
|
||||
elif fn == "Quantity":
|
||||
if self.in_core_time:
|
||||
node = ast.copy_location(
|
||||
ast.BinOp(left=node.args[0],
|
||||
op=ast.Div(),
|
||||
right=value_to_ast(self.ref_period)),
|
||||
node)
|
||||
self.generic_visit(node)
|
||||
if node.func.id == "Quantity":
|
||||
amount, unit = node.args
|
||||
amount.unit = unit.s
|
||||
return amount
|
||||
elif node.func.id == "now":
|
||||
node.unit = "s"
|
||||
elif node.func.id == "syscall" and node.args[0].s == "rpc":
|
||||
unit_list = [getattr(arg, "unit", None) for arg in node.args]
|
||||
rpc_n = node.args[1].n
|
||||
self.rpc_map[rpc_n] = _add_units(self.rpc_map[rpc_n], unit_list)
|
||||
return node
|
||||
|
||||
def _update_target(self, target, unit):
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in self.variable_units:
|
||||
if self.variable_units[target.id] != unit:
|
||||
raise TypeError(
|
||||
"Inconsistent units for variable '{}': '{}' and '{}'"
|
||||
.format(target.id,
|
||||
self.variable_units[target.id],
|
||||
unit))
|
||||
else:
|
||||
node = node.args[0]
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
self.variable_units[target.id] = unit
|
||||
|
||||
def visit_Assign(self, node):
|
||||
node.value = self.visit(node.value)
|
||||
unit = getattr(node.value, "unit", None)
|
||||
for target in node.targets:
|
||||
self._update_target(target, unit)
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
value = self.visit_BinOp(ast.BinOp(
|
||||
op=node.op, left=node.target, right=node.value))
|
||||
unit = getattr(value, "unit", None)
|
||||
self._update_target(node.target, unit)
|
||||
return node
|
||||
|
||||
# Only dimensionless iterators are supported
|
||||
def visit_For(self, node):
|
||||
self.generic_visit(node)
|
||||
self._update_target(node.target, None)
|
||||
return node
|
||||
|
||||
|
||||
def lower_units(func_def, ref_period):
|
||||
_UnitsLowerer(ref_period).visit(func_def)
|
||||
def lower_units(func_def, rpc_map):
|
||||
_UnitsLowerer(rpc_map).visit(func_def)
|
||||
|
|
Loading…
Reference in New Issue