From 08ff5154a8dd6ff92b6f015eaa1026fdcb7a9849 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 25 May 2014 13:31:02 +0200 Subject: [PATCH] transform: explicit delays and unroll for loops --- examples/collapse_test.py | 13 ++++ examples/transform.py | 150 ++++++++++++++++++++++++++------------ 2 files changed, 115 insertions(+), 48 deletions(-) create mode 100644 examples/collapse_test.py diff --git a/examples/collapse_test.py b/examples/collapse_test.py new file mode 100644 index 000000000..623e4a41c --- /dev/null +++ b/examples/collapse_test.py @@ -0,0 +1,13 @@ +def collapse_test(): + for i in range(3): + with parallel: + with sequential: + pulse("a", 100*MHz, 10*us) + delay(10*us) + pulse("b", 100*MHz, 10*us) + delay(20*us) + with sequential: + pulse("a", 100*MHz, 10*us) + delay(10*us) + pulse("b", 100*MHz, 10*us) + delay(10*us) diff --git a/examples/transform.py b/examples/transform.py index ba5d4b94e..931a70873 100644 --- a/examples/transform.py +++ b/examples/transform.py @@ -2,55 +2,109 @@ import inspect, textwrap, ast from artiq import units, unparse -_now = "_ARTIQ_now" - -class _RequestTransformer(ast.NodeTransformer): - def __init__(self, target_globals): - self.target_globals = target_globals - - def visit_FunctionDef(self, node): - self.generic_visit(node) - node.body.insert(0, ast.copy_location( - ast.Assign(targets=[ast.Name(id=_now, ctx=ast.Store())], - value=ast.Num(n=0)), node)) - node.body.append(ast.copy_location( - ast.Return(value=ast.Name(id=_now, ctx=ast.Store())), - node)) - return node - - def visit_Return(self, node): - raise TypeError("Kernels cannot return values") - - def visit_Call(self, node): - self.generic_visit(node) - name = node.func.id - if name == "delay": - if len(node.args) != 1: - raise TypeError("delay() takes 1 positional argument but {} were given".format(len(node.args))) - return ast.copy_location(ast.AugAssign( - target=ast.Name(id=_now, ctx=ast.Store()), - op=ast.Add(), value=node.args[0]), node) - return node - - def visit_Name(self, node): - if not isinstance(node.ctx, ast.Load): - return node - try: - obj = self.target_globals[node.id] - except KeyError: - return node - if isinstance(obj, units.Quantity): - return ast.Num(obj.amount) +def find_kernel_body(node): + while True: + if isinstance(node, ast.Module): + if len(node.body) != 1: + raise TypeError + node = node.body[0] + elif isinstance(node, ast.FunctionDef): + return node.body else: - return node + raise TypeError -def request_transform(target_ast, target_globals): - transformer = _RequestTransformer(target_globals) - transformer.visit(target_ast) +def eval_ast(expr, symdict=dict()): + if not isinstance(expr, ast.Expression): + expr = ast.Expression(expr) + code = compile(expr, "", "eval") + return eval(code, symdict) + +def _try_eval_with_units(node): + try: + r = eval_ast(node, units.__dict__) + except: + return node + if isinstance(r, units.Quantity): + return ast.copy_location(ast.Num(n=r.amount), node) + else: + return node + +def explicit_delays(stmts): + insertions = [] + for i, stmt in enumerate(stmts): + if isinstance(stmt, (ast.For, ast.While, ast.If)): + explicit_delays(stmt.body) + explicit_delays(stmt.orelse) + elif isinstance(stmt, ast.With): + explicit_delays(stmt.body) + elif isinstance(stmt, ast.Expr): + if not isinstance(stmt.value, ast.Call) or not isinstance(stmt.value.func, ast.Name): + continue + call = stmt.value + name = call.func.id + if name == "delay": + call.args[0] = _try_eval_with_units(call.args[0]) + elif name == "pulse": + call.func.id = "pulse_start" + insertions.append((i+1, ast.copy_location( + ast.Expr(ast.Call(func=ast.Name(id="delay", ctx=ast.Load()), + args=[_try_eval_with_units(call.args[2])], + keywords=[], starargs=[], kwargs=[])), + stmt))) + for i, (location, stmt) in enumerate(insertions): + stmts.insert(location+i, stmt) + +def _count_stmts(node): + if isinstance(node, (ast.For, ast.While, ast.If)): + print(ast.dump(node)) + return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) + elif isinstance(node, ast.With): + return 1 + _count_stmts(node.body) + elif isinstance(node, list): + return sum(map(_count_stmts, node)) + else: + return 1 + +def unroll_loops(stmts, limit): + replacements = [] + for stmt_i, stmt in enumerate(stmts): + if isinstance(stmt, ast.For): + try: + it = eval_ast(stmt.iter) + except: + pass + else: + unroll_loops(stmt.body, limit) + n = len(it)*_count_stmts(stmt.body) + if n < limit: + replacement = [] + for i in it: + if not isinstance(i, int): + replacement = None + break + replacement.append(ast.copy_location( + ast.Assign(targets=[stmt.target], value=ast.Num(n=i)), stmt)) + replacement += stmt.body + if replacement is not None: + replacements.append((stmt_i, replacement)) + if isinstance(stmt, (ast.While, ast.If)): + unroll_loops(stmt.body, limit) + unroll_loops(stmt.orelse, limit) + elif isinstance(stmt, ast.With): + unroll_loops(stmt.body, limit) + offset = 0 + for location, new_stmts in replacements: + stmts[offset+location:offset+location+1] = new_stmts + offset += len(new_stmts) - 1 if __name__ == "__main__": - import threads_test - kernel = threads_test.threads_test - a = ast.parse(textwrap.dedent(inspect.getsource(kernel))) - request_transform(a, kernel.__globals__) - unparse.Unparser(a) + import collapse_test + kernel = collapse_test.collapse_test + + node = ast.parse(textwrap.dedent(inspect.getsource(kernel))) + node = find_kernel_body(node) + + explicit_delays(node) + unroll_loops(node, 50) + + unparse.Unparser(node)