diff --git a/artiq/compiler/interleave.py b/artiq/compiler/interleave.py new file mode 100644 index 000000000..eac041a14 --- /dev/null +++ b/artiq/compiler/interleave.py @@ -0,0 +1,115 @@ +import ast, types + +from artiq.language import units +from artiq.compiler.tools import eval_ast + +# -1 statement duration could not be pre-determined +# 0 statement has no effect on timeline +# >0 statement is a static delay that advances the timeline +# by the given amount (in base_s_unit) +def _get_duration(stmt): + if isinstance(stmt, (ast.Expr, ast.Assign)): + return _get_duration(stmt.value) + elif isinstance(stmt, ast.If): + if all(_get_duration(s) == 0 for s in stmt.body) and all(_get_duration(s) == 0 for s in stmt.orelse): + return 0 + else: + return -1 + elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name): + name = stmt.func.id + if name == "delay": + da = stmt.args[0] + if isinstance(da, ast.Call) \ + and isinstance(da.func, ast.Name) \ + and da.func.id == "Quantity" \ + and isinstance(da.args[0], ast.Num): + if not isinstance(da.args[1], ast.Name) or da.args[1].id != "base_s_unit": + raise units.DimensionError("Delay not expressed in seconds") + return da.args[0].n + else: + return -1 + else: + return 0 + else: + return 0 + +def _interleave_timelines(timelines): + r = [] + + current_stmts = [] + for stmts in timelines: + it = iter(stmts) + try: + stmt = next(it) + except StopIteration: + pass + else: + current_stmts.append(types.SimpleNamespace(delay=_get_duration(stmt), stmt=stmt, it=it)) + + while current_stmts: + dt = min(stmt.delay for stmt in current_stmts) + print("\n".join("{} -> {}".format(ast.dump(stmt.stmt), stmt.delay) for stmt in current_stmts)) + print("") + if dt < 0: + # contains statement(s) with indeterminate duration + return None + if dt > 0: + # advance timeline by dt + for stmt in current_stmts: + stmt.delay -= dt + if stmt.delay == 0: + ref_stmt = stmt.stmt + da_expr = ast.copy_location( + ast.Call(func=ast.Name("Quantity", ast.Load()), + args=[ast.Num(dt), ast.Name("base_s_unit", ast.Load())], + keywords=[], starargs=[], kwargs=[]), + ref_stmt) + delay_stmt = ast.copy_location( + ast.Expr(ast.Call(func=ast.Name("delay", ast.Load()), + args=[da_expr], + keywords=[], starargs=[], kwargs=[])), + ref_stmt) + r.append(delay_stmt) + else: + for stmt in current_stmts: + if stmt.delay == 0: + r.append(stmt.stmt) + # discard executed statements + exhausted_list = [] + for stmt_i, stmt in enumerate(current_stmts): + if stmt.delay == 0: + try: + stmt.stmt = next(stmt.it) + except StopIteration: + exhausted_list.append(stmt_i) + else: + stmt.delay = _get_duration(stmt.stmt) + for offset, i in enumerate(exhausted_list): + current_stmts.pop(i-offset) + + return r + +def interleave(stmts): + replacements = [] + for stmt_i, stmt in enumerate(stmts): + if isinstance(stmt, (ast.For, ast.While, ast.If)): + interleave(stmt.body) + interleave(stmt.orelse) + elif isinstance(stmt, ast.With): + btype = stmt.items[0].context_expr.id + if btype == "sequential": + interleave(stmt.body) + replacements.append((stmt_i, stmt.body)) + elif btype == "parallel": + timelines = [[s] for s in stmt.body] + for timeline in timelines: + interleave(timeline) + merged = _interleave_timelines(timelines) + if merged is not None: + replacements.append((stmt_i, merged)) + else: + raise ValueError("Unknown block type: " + btype) + offset = 0 + for location, new_stmts in replacements: + stmts[offset+location:offset+location+1] = new_stmts + offset += len(new_stmts) - 1 diff --git a/artiq/compiler/transform.py b/artiq/compiler/transform.py deleted file mode 100644 index 05010fbcd..000000000 --- a/artiq/compiler/transform.py +++ /dev/null @@ -1,203 +0,0 @@ -import inspect, textwrap, ast, types - -from artiq.language import units -from artiq.compiler import unparse -from artiq.compiler.tools import eval_ast - -def find_kernel_body(node): - while True: - if isinstance(node, ast.Module): - if len(node.body) != 1: - raise TypeError - node = node.body[0] - elif isinstance(node, ast.FunctionDef): - return node.body - else: - raise TypeError - -def _try_eval_with_units(node): - try: - r = eval_ast(node, units.__dict__) - except: - return node - if isinstance(r, units.Quantity): - return ast.copy_location(ast.Num(n=r.amount), node) - else: - return node - -def explicit_delays(stmts): - insertions = [] - for i, stmt in enumerate(stmts): - if isinstance(stmt, (ast.For, ast.While, ast.If)): - explicit_delays(stmt.body) - explicit_delays(stmt.orelse) - elif isinstance(stmt, ast.With): - explicit_delays(stmt.body) - elif isinstance(stmt, ast.Expr): - if not isinstance(stmt.value, ast.Call) or not isinstance(stmt.value.func, ast.Name): - continue - call = stmt.value - name = call.func.id - if name == "delay": - call.args[0] = _try_eval_with_units(call.args[0]) - elif name == "pulse": - call.func.id = "pulse_start" - insertions.append((i+1, ast.copy_location( - ast.Expr(ast.Call(func=ast.Name(id="delay", ctx=ast.Load()), - args=[_try_eval_with_units(call.args[2])], - keywords=[], starargs=[], kwargs=[])), - stmt))) - for i, (location, stmt) in enumerate(insertions): - stmts.insert(location+i, stmt) - -def _count_stmts(node): - if isinstance(node, (ast.For, ast.While, ast.If)): - print(ast.dump(node)) - return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) - elif isinstance(node, ast.With): - return 1 + _count_stmts(node.body) - elif isinstance(node, list): - return sum(map(_count_stmts, node)) - else: - return 1 - -def unroll_loops(stmts, limit): - replacements = [] - for stmt_i, stmt in enumerate(stmts): - if isinstance(stmt, ast.For): - try: - it = eval_ast(stmt.iter) - except: - pass - else: - unroll_loops(stmt.body, limit) - unroll_loops(stmt.orelse, limit) - l_it = len(it) - if l_it: - n = l_it*_count_stmts(stmt.body) - if n < limit: - replacement = [] - for i in it: - if not isinstance(i, int): - replacement = None - break - replacement.append(ast.copy_location( - ast.Assign(targets=[stmt.target], value=ast.Num(n=i)), stmt)) - replacement += stmt.body - if replacement is not None: - replacements.append((stmt_i, replacement)) - else: - replacements.append((stmt_i, stmt.orelse)) - if isinstance(stmt, (ast.While, ast.If)): - unroll_loops(stmt.body, limit) - unroll_loops(stmt.orelse, limit) - elif isinstance(stmt, ast.With): - unroll_loops(stmt.body, limit) - offset = 0 - for location, new_stmts in replacements: - stmts[offset+location:offset+location+1] = new_stmts - offset += len(new_stmts) - 1 - -# -1 statement duration could not be pre-determined -# 0 statement has no effect on timeline -# >0 statement is a static delay that advances the timeline by the given amount -def _get_duration(stmt): - if isinstance(stmt, (ast.Expr, ast.Assign)): - return _get_duration(stmt.value) - elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name): - name = stmt.func.id - if name == "delay": - if isinstance(stmt.args[0], ast.Num): - return stmt.args[0].n - else: - return -1 - elif name == "wait_edge": - return -1 - else: - return 0 - else: - return -1 - -def _merge_timelines(timelines): - r = [] - - current_stmts = [] - for stmts in timelines: - it = iter(stmts) - try: - stmt = next(it) - except StopIteration: - pass - else: - current_stmts.append(types.SimpleNamespace(delay=_get_duration(stmt), stmt=stmt, it=it)) - - while current_stmts: - dt = min(stmt.delay for stmt in current_stmts) - if dt < 0: - # contains statement(s) with indeterminate duration - return None - if dt > 0: - # advance timeline by dt - for stmt in current_stmts: - stmt.delay -= dt - if stmt.delay == 0: - ref_stmt = stmt.stmt - delay_stmt = ast.copy_location( - ast.Expr(ast.Call(func=ast.Name(id="delay", ctx=ast.Load()), - args=[ast.Num(n=dt)], - keywords=[], starargs=[], kwargs=[])), - ref_stmt) - r.append(delay_stmt) - else: - for stmt in current_stmts: - if stmt.delay == 0: - r.append(stmt.stmt) - # discard executed statements - exhausted_list = [] - for stmt_i, stmt in enumerate(current_stmts): - if stmt.delay == 0: - try: - stmt.stmt = next(stmt.it) - except StopIteration: - exhausted_list.append(stmt_i) - else: - stmt.delay = _get_duration(stmt.stmt) - for offset, i in enumerate(exhausted_list): - current_stmts.pop(i-offset) - - return r - -def collapse(stmts): - replacements = [] - for stmt_i, stmt in enumerate(stmts): - if isinstance(stmt, (ast.For, ast.While, ast.If)): - collapse(stmt.body) - collapse(stmt.orelse) - elif isinstance(stmt, ast.With): - btype = stmt.items[0].context_expr.id - if btype == "sequential": - collapse(stmt.body) - replacements.append((stmt_i, stmt.body)) - elif btype == "parallel": - timelines = [[s] for s in stmt.body] - for timeline in timelines: - collapse(timeline) - merged = _merge_timelines(timelines) - if merged is not None: - replacements.append((stmt_i, merged)) - else: - raise ValueError("Unknown block type: " + btype) - offset = 0 - for location, new_stmts in replacements: - stmts[offset+location:offset+location+1] = new_stmts - offset += len(new_stmts) - 1 - -def transform(k_function, k_args, k_kwargs): - node = ast.parse(textwrap.dedent(inspect.getsource(k_function))) - node = find_kernel_body(node) - - explicit_delays(node) - unroll_loops(node, 50) - collapse(node) - - unparse.Unparser(node) diff --git a/artiq/devices/core.py b/artiq/devices/core.py index 3e924adc0..c027127c6 100644 --- a/artiq/devices/core.py +++ b/artiq/devices/core.py @@ -3,6 +3,7 @@ from operator import itemgetter from artiq.compiler.inline import inline from artiq.compiler.fold_constants import fold_constants from artiq.compiler.unroll_loops import unroll_loops +from artiq.compiler.interleave import interleave from artiq.compiler.unparse import Unparser class Core: @@ -10,6 +11,7 @@ class Core: stmts, rpc_map = inline(self, k_function, k_args, k_kwargs) fold_constants(stmts) unroll_loops(stmts, 50) + interleave(stmts) print("=========================") print(" Inlined") diff --git a/examples/compiler_test.py b/examples/compiler_test.py index fcd4e55c1..35916a935 100644 --- a/examples/compiler_test.py +++ b/examples/compiler_test.py @@ -14,14 +14,14 @@ class CompilerTest(Experiment): @kernel def run(self, n, t2): - t2 += 1*us for i in my_range(n): self.set_some_slowdev(i) delay(100*ms) with parallel: with sequential: - self.a.pulse(100*MHz, 20*us) - self.b.pulse(100*MHz, t2) + for j in my_range(3): + self.a.pulse((j+1)*100*MHz, 20*us) + self.b.pulse(100*MHz, t2) with sequential: self.A.pulse(100*MHz, 10*us) self.B.pulse(100*MHz, t2)