mirror of https://github.com/m-labs/artiq.git
transform: explicit delays and unroll for loops
This commit is contained in:
parent
e3620cc61f
commit
08ff5154a8
|
@ -0,0 +1,13 @@
|
||||||
|
def collapse_test():
|
||||||
|
for i in range(3):
|
||||||
|
with parallel:
|
||||||
|
with sequential:
|
||||||
|
pulse("a", 100*MHz, 10*us)
|
||||||
|
delay(10*us)
|
||||||
|
pulse("b", 100*MHz, 10*us)
|
||||||
|
delay(20*us)
|
||||||
|
with sequential:
|
||||||
|
pulse("a", 100*MHz, 10*us)
|
||||||
|
delay(10*us)
|
||||||
|
pulse("b", 100*MHz, 10*us)
|
||||||
|
delay(10*us)
|
|
@ -2,55 +2,109 @@ import inspect, textwrap, ast
|
||||||
|
|
||||||
from artiq import units, unparse
|
from artiq import units, unparse
|
||||||
|
|
||||||
_now = "_ARTIQ_now"
|
def find_kernel_body(node):
|
||||||
|
while True:
|
||||||
class _RequestTransformer(ast.NodeTransformer):
|
if isinstance(node, ast.Module):
|
||||||
def __init__(self, target_globals):
|
if len(node.body) != 1:
|
||||||
self.target_globals = target_globals
|
raise TypeError
|
||||||
|
node = node.body[0]
|
||||||
def visit_FunctionDef(self, node):
|
elif isinstance(node, ast.FunctionDef):
|
||||||
self.generic_visit(node)
|
return node.body
|
||||||
node.body.insert(0, ast.copy_location(
|
|
||||||
ast.Assign(targets=[ast.Name(id=_now, ctx=ast.Store())],
|
|
||||||
value=ast.Num(n=0)), node))
|
|
||||||
node.body.append(ast.copy_location(
|
|
||||||
ast.Return(value=ast.Name(id=_now, ctx=ast.Store())),
|
|
||||||
node))
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Return(self, node):
|
|
||||||
raise TypeError("Kernels cannot return values")
|
|
||||||
|
|
||||||
def visit_Call(self, node):
|
|
||||||
self.generic_visit(node)
|
|
||||||
name = node.func.id
|
|
||||||
if name == "delay":
|
|
||||||
if len(node.args) != 1:
|
|
||||||
raise TypeError("delay() takes 1 positional argument but {} were given".format(len(node.args)))
|
|
||||||
return ast.copy_location(ast.AugAssign(
|
|
||||||
target=ast.Name(id=_now, ctx=ast.Store()),
|
|
||||||
op=ast.Add(), value=node.args[0]), node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Name(self, node):
|
|
||||||
if not isinstance(node.ctx, ast.Load):
|
|
||||||
return node
|
|
||||||
try:
|
|
||||||
obj = self.target_globals[node.id]
|
|
||||||
except KeyError:
|
|
||||||
return node
|
|
||||||
if isinstance(obj, units.Quantity):
|
|
||||||
return ast.Num(obj.amount)
|
|
||||||
else:
|
else:
|
||||||
return node
|
raise TypeError
|
||||||
|
|
||||||
def request_transform(target_ast, target_globals):
|
def eval_ast(expr, symdict=dict()):
|
||||||
transformer = _RequestTransformer(target_globals)
|
if not isinstance(expr, ast.Expression):
|
||||||
transformer.visit(target_ast)
|
expr = ast.Expression(expr)
|
||||||
|
code = compile(expr, "<ast>", "eval")
|
||||||
|
return eval(code, symdict)
|
||||||
|
|
||||||
|
def _try_eval_with_units(node):
|
||||||
|
try:
|
||||||
|
r = eval_ast(node, units.__dict__)
|
||||||
|
except:
|
||||||
|
return node
|
||||||
|
if isinstance(r, units.Quantity):
|
||||||
|
return ast.copy_location(ast.Num(n=r.amount), node)
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
|
def explicit_delays(stmts):
|
||||||
|
insertions = []
|
||||||
|
for i, stmt in enumerate(stmts):
|
||||||
|
if isinstance(stmt, (ast.For, ast.While, ast.If)):
|
||||||
|
explicit_delays(stmt.body)
|
||||||
|
explicit_delays(stmt.orelse)
|
||||||
|
elif isinstance(stmt, ast.With):
|
||||||
|
explicit_delays(stmt.body)
|
||||||
|
elif isinstance(stmt, ast.Expr):
|
||||||
|
if not isinstance(stmt.value, ast.Call) or not isinstance(stmt.value.func, ast.Name):
|
||||||
|
continue
|
||||||
|
call = stmt.value
|
||||||
|
name = call.func.id
|
||||||
|
if name == "delay":
|
||||||
|
call.args[0] = _try_eval_with_units(call.args[0])
|
||||||
|
elif name == "pulse":
|
||||||
|
call.func.id = "pulse_start"
|
||||||
|
insertions.append((i+1, ast.copy_location(
|
||||||
|
ast.Expr(ast.Call(func=ast.Name(id="delay", ctx=ast.Load()),
|
||||||
|
args=[_try_eval_with_units(call.args[2])],
|
||||||
|
keywords=[], starargs=[], kwargs=[])),
|
||||||
|
stmt)))
|
||||||
|
for i, (location, stmt) in enumerate(insertions):
|
||||||
|
stmts.insert(location+i, stmt)
|
||||||
|
|
||||||
|
def _count_stmts(node):
|
||||||
|
if isinstance(node, (ast.For, ast.While, ast.If)):
|
||||||
|
print(ast.dump(node))
|
||||||
|
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
|
||||||
|
elif isinstance(node, ast.With):
|
||||||
|
return 1 + _count_stmts(node.body)
|
||||||
|
elif isinstance(node, list):
|
||||||
|
return sum(map(_count_stmts, node))
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def unroll_loops(stmts, limit):
|
||||||
|
replacements = []
|
||||||
|
for stmt_i, stmt in enumerate(stmts):
|
||||||
|
if isinstance(stmt, ast.For):
|
||||||
|
try:
|
||||||
|
it = eval_ast(stmt.iter)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
unroll_loops(stmt.body, limit)
|
||||||
|
n = len(it)*_count_stmts(stmt.body)
|
||||||
|
if n < limit:
|
||||||
|
replacement = []
|
||||||
|
for i in it:
|
||||||
|
if not isinstance(i, int):
|
||||||
|
replacement = None
|
||||||
|
break
|
||||||
|
replacement.append(ast.copy_location(
|
||||||
|
ast.Assign(targets=[stmt.target], value=ast.Num(n=i)), stmt))
|
||||||
|
replacement += stmt.body
|
||||||
|
if replacement is not None:
|
||||||
|
replacements.append((stmt_i, replacement))
|
||||||
|
if isinstance(stmt, (ast.While, ast.If)):
|
||||||
|
unroll_loops(stmt.body, limit)
|
||||||
|
unroll_loops(stmt.orelse, limit)
|
||||||
|
elif isinstance(stmt, ast.With):
|
||||||
|
unroll_loops(stmt.body, limit)
|
||||||
|
offset = 0
|
||||||
|
for location, new_stmts in replacements:
|
||||||
|
stmts[offset+location:offset+location+1] = new_stmts
|
||||||
|
offset += len(new_stmts) - 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import threads_test
|
import collapse_test
|
||||||
kernel = threads_test.threads_test
|
kernel = collapse_test.collapse_test
|
||||||
a = ast.parse(textwrap.dedent(inspect.getsource(kernel)))
|
|
||||||
request_transform(a, kernel.__globals__)
|
node = ast.parse(textwrap.dedent(inspect.getsource(kernel)))
|
||||||
unparse.Unparser(a)
|
node = find_kernel_body(node)
|
||||||
|
|
||||||
|
explicit_delays(node)
|
||||||
|
unroll_loops(node, 50)
|
||||||
|
|
||||||
|
unparse.Unparser(node)
|
||||||
|
|
Loading…
Reference in New Issue