From d11f66291c19157bf5262c53ca6533a668722dbf Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 20 Apr 2018 12:18:58 +0000 Subject: [PATCH] compiler: desugar x != y into not x == y (fixes #974). --- .../compiler/transforms/artiq_ir_generator.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index a2f1971dc..df274b794 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1478,7 +1478,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - def polymorphic_compare_pair_inclusion(self, op, needle, haystack): + def polymorphic_compare_pair_inclusion(self, needle, haystack): if builtins.is_range(haystack.type): # Optimized range `in` operator start = self.append(ir.GetAttr(haystack, "start")) @@ -1522,21 +1522,29 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - if isinstance(op, ast.NotIn): - result = self.append(ir.Select(result, - ir.Constant(False, builtins.TBool()), - ir.Constant(True, builtins.TBool()))) - return result + def invert(self, value): + return self.append(ir.Select(value, + ir.Constant(False, builtins.TBool()), + ir.Constant(True, builtins.TBool()))) + def polymorphic_compare_pair(self, op, lhs, rhs): if isinstance(op, (ast.Is, ast.IsNot)): # The backend will handle equality of aggregates. return self.append(ir.Compare(op, lhs, rhs)) - elif isinstance(op, (ast.In, ast.NotIn)): - return self.polymorphic_compare_pair_inclusion(op, lhs, rhs) - else: # Eq, NotEq, Lt, LtE, Gt, GtE + elif isinstance(op, ast.In): + return self.polymorphic_compare_pair_inclusion(lhs, rhs) + elif isinstance(op, ast.NotIn): + result = self.polymorphic_compare_pair_inclusion(lhs, rhs) + return self.invert(result) + elif isinstance(op, (ast.Eq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)): return self.polymorphic_compare_pair_order(op, lhs, rhs) + elif isinstance(op, ast.NotEq): + result = self.polymorphic_compare_pair_order(ast.Eq(loc=op.loc), lhs, rhs) + return self.invert(result) + else: + assert False def visit_CompareT(self, node): # Essentially a sequence of `and`s performed over results