diff --git a/artiq/compiler/fold_constants.py b/artiq/compiler/fold_constants.py index 2dbcf835e..3ab9d31c0 100644 --- a/artiq/compiler/fold_constants.py +++ b/artiq/compiler/fold_constants.py @@ -1,7 +1,7 @@ import ast, operator from artiq.language import units -from artiq.compiler.tools import value_to_ast +from artiq.compiler.tools import value_to_ast, make_stmt_transformer class _NotConstant(Exception): pass @@ -79,7 +79,4 @@ class _ConstantFolder(ast.NodeTransformer): return node return result -def fold_constants(stmts): - constant_folder = _ConstantFolder() - new_stmts = [constant_folder.visit(stmt) for stmt in stmts] - stmts[:] = new_stmts +fold_constants = make_stmt_transformer(_ConstantFolder) diff --git a/artiq/compiler/tools.py b/artiq/compiler/tools.py index 1e91a6a43..8cecd5580 100644 --- a/artiq/compiler/tools.py +++ b/artiq/compiler/tools.py @@ -23,3 +23,10 @@ def value_to_ast(value): args=[ast.Num(value.amount), ast.Name("base_"+value.unit.name+"_unit", ast.Load())], keywords=[], starargs=None, kwargs=None) return None + +def make_stmt_transformer(transformer_class): + def stmt_transformer(stmts): + transformer = transformer_class() + new_stmts = [transformer.visit(stmt) for stmt in stmts] + stmts[:] = new_stmts + return stmt_transformer