diff --git a/artiq/compiler/fold_constants.py b/artiq/compiler/fold_constants.py new file mode 100644 index 000000000..2dbcf835e --- /dev/null +++ b/artiq/compiler/fold_constants.py @@ -0,0 +1,85 @@ +import ast, operator + +from artiq.language import units +from artiq.compiler.tools import value_to_ast + +class _NotConstant(Exception): + pass + +def _get_constant(node): + if isinstance(node, ast.Num): + return node.n + elif isinstance(node, ast.Str): + return node.s + elif isinstance(node, ast.Call) \ + and isinstance(node.func, ast.Name) \ + and node.func.id == "Quantity": + amount, unit = node.args + amount = _get_constant(amount) + try: + unit = getattr(units, unit.id) + except: + raise _NotConstant + return units.Quantity(amount, unit) + else: + raise _NotConstant + +_ast_unops = { + ast.Invert: operator.inv, + ast.Not: operator.not_, + ast.UAdd: operator.pos, + ast.USub: operator.neg +} + +_ast_binops = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, + ast.BitOr: operator.or_, + ast.BitXor: operator.xor, + ast.BitAnd: operator.and_ +} + +class _ConstantFolder(ast.NodeTransformer): + def visit_UnaryOp(self, node): + self.generic_visit(node) + try: + operand = _get_constant(node.operand) + except _NotConstant: + return node + try: + op = _ast_unops[type(node.op)] + except KeyError: + return node + try: + result = value_to_ast(op(operand)) + except: + return node + return result + + def visit_BinOp(self, node): + self.generic_visit(node) + try: + left, right = _get_constant(node.left), _get_constant(node.right) + except _NotConstant: + return node + try: + op = _ast_binops[type(node.op)] + except KeyError: + return node + try: + result = value_to_ast(op(left, right)) + except: + return node + return result + +def fold_constants(stmts): + constant_folder = _ConstantFolder() + new_stmts = [constant_folder.visit(stmt) for stmt in stmts] + stmts[:] = new_stmts diff --git a/artiq/devices/core.py b/artiq/devices/core.py index 66eb365ab..ca225520a 100644 --- a/artiq/devices/core.py +++ b/artiq/devices/core.py @@ -1,11 +1,14 @@ from operator import itemgetter from artiq.compiler.inline import inline +from artiq.compiler.fold_constants import fold_constants from artiq.compiler.unparse import Unparser class Core: def run(self, k_function, k_args, k_kwargs): stmts, rpc_map = inline(self, k_function, k_args, k_kwargs) + fold_constants(stmts) + print("=========================") print(" Inlined")