diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 5807d7308..50d70eb7c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -40,6 +40,7 @@ use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, Unaryop, }; +use std::cmp::min; use std::iter::{repeat, repeat_with}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; @@ -2303,6 +2304,119 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }; gen_list_cmpop(generator, ctx)? + } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { + let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + + if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { + todo!("Only __eq__ and __ne__ is implemented for tuples") + } + + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + // Assume `true` by default + let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap(); + ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap(); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); + ctx.builder.position_at_end(current_bb); + + // Generate comparison between each element + let min_len = min(left_tys.len(), right_tys.len()); + for i in 0..min_len { + let current_bb = ctx.builder.get_insert_block().unwrap(); + let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e")); + ctx.builder.build_unconditional_branch(bb).unwrap(); + + ctx.builder.position_at_end(bb); + let left_ty = left_tys[i]; + let left_elem = { + let plhs = generator.gen_var_alloc(ctx, lhs.get_type(), None).unwrap(); + ctx.builder.build_store(plhs, *lhs).unwrap(); + + ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ) + }; + let right_ty = right_tys[i]; + let right_elem = { + let prhs = generator.gen_var_alloc(ctx, rhs.get_type(), None).unwrap(); + ctx.builder.build_store(prhs, *rhs).unwrap(); + + ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ) + }; + + gen_if_callback( + generator, + ctx, + |generator, ctx| { + // Defer the `not` operation until the end - a != b <=> !(a == b) + let op = if *op == Cmpop::NotEq { Cmpop::Eq } else { *op }; + + let cmp = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty), left_elem), + &[op], + &[(Some(right_ty), right_elem)], + ) + .transpose() + .unwrap() + .and_then(|v| { + v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + }) + .map(BasicValueEnum::into_int_value)?; + + Ok(ctx.builder.build_not(cmp, "").unwrap()) + }, + |_, ctx| { + let bb = ctx.builder.get_insert_block().unwrap(); + cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + )?; + } + + // Length of tuples is checked last as operators do not short-circuit by tuple + // length in Python: + // + // >>> (1, 2) < ("a",) + // TypeError: '<' not supported between instances of 'int' and 'str' + let bb = ctx.builder.get_insert_block().unwrap(); + let is_len_eq = llvm_i1.const_int( + u64::from(left_tys.len() == right_tys.len()), + false, + ); + cmp_phi.add_incoming(&[(&is_len_eq, bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = cmp_phi.as_basic_value().into_int_value(); + + // Invert the final value if __ne__ + if *op == Cmpop::NotEq { + ctx.builder.build_not(cmp_phi, "").unwrap() + } else { + cmp_phi + } } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { ctx.make_assert(