diff --git a/artiq/transforms/fold_constants.py b/artiq/transforms/fold_constants.py index 1258796f4..ff9aac9ca 100644 --- a/artiq/transforms/fold_constants.py +++ b/artiq/transforms/fold_constants.py @@ -28,6 +28,14 @@ _ast_binops = { ast.BitAnd: operator.and_ } +_ast_cmpops = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge +} _ast_boolops = { ast.Or: lambda x, y: x or y, @@ -68,6 +76,36 @@ class _ConstantFolder(ast.NodeTransformer): return node return ast.copy_location(result, node) + def visit_Compare(self, node): + self.generic_visit(node) + try: + operands = [eval_constant(node.left)] + except NotConstant: + operands = [node.left] + ops = [] + for op, right_ast in zip(node.ops, node.comparators): + try: + right = eval_constant(right_ast) + except NotConstant: + right = right_ast + if (not isinstance(operands[-1], ast.AST) + and not isinstance(right, ast.AST)): + left = operands.pop() + operands.append(_ast_cmpops[type(op)](left, right)) + else: + ops.append(op) + operands.append(right_ast) + operands = [operand if isinstance(operand, ast.AST) + else ast.copy_location(value_to_ast(operand), node) + for operand in operands] + if len(operands) == 1: + return operands[0] + else: + node.left = operands[0] + node.right = operands[1:] + node.ops = ops + return node + def visit_BoolOp(self, node): self.generic_visit(node) new_values = []