transforms/fold_constants: support comparisons

This commit is contained in:
Sebastien Bourdeauducq 2014-10-29 18:46:06 +08:00
parent c82c631a1d
commit be94a8b07c
1 changed files with 38 additions and 0 deletions

View File

@ -28,6 +28,14 @@ _ast_binops = {
ast.BitAnd: operator.and_ ast.BitAnd: operator.and_
} }
_ast_cmpops = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge
}
_ast_boolops = { _ast_boolops = {
ast.Or: lambda x, y: x or y, ast.Or: lambda x, y: x or y,
@ -68,6 +76,36 @@ class _ConstantFolder(ast.NodeTransformer):
return node return node
return ast.copy_location(result, node) return ast.copy_location(result, node)
def visit_Compare(self, node):
self.generic_visit(node)
try:
operands = [eval_constant(node.left)]
except NotConstant:
operands = [node.left]
ops = []
for op, right_ast in zip(node.ops, node.comparators):
try:
right = eval_constant(right_ast)
except NotConstant:
right = right_ast
if (not isinstance(operands[-1], ast.AST)
and not isinstance(right, ast.AST)):
left = operands.pop()
operands.append(_ast_cmpops[type(op)](left, right))
else:
ops.append(op)
operands.append(right_ast)
operands = [operand if isinstance(operand, ast.AST)
else ast.copy_location(value_to_ast(operand), node)
for operand in operands]
if len(operands) == 1:
return operands[0]
else:
node.left = operands[0]
node.right = operands[1:]
node.ops = ops
return node
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
self.generic_visit(node) self.generic_visit(node)
new_values = [] new_values = []