From 1ab9a118d6155f61d0ca3b44de0c6cdd8934d9c4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 20 Aug 2024 19:49:52 +0800 Subject: [PATCH] [core] codegen/expr: Implement comparison of tuples --- nac3core/src/codegen/expr.rs | 107 +++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 91322d58..2af57275 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -40,6 +40,7 @@ use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, Unaryop, }; +use std::cmp::min; use std::iter::{repeat, repeat_with}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; @@ -2207,6 +2208,112 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } ctx.ctx.bool_type().get_poison() + } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { + let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + // Assume `true` by default, similar to reduce-like operations + let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap(); + ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap(); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); + + // Generate comparison between each element: + // for i in 0..min(lhs.len(), rhs.len()) { + // if !(lhs op rhs) { + // cmp = false; + // goto end; + // } + // } + let min_len = min(left_tys.len(), right_tys.len()); + for i in 0..min_len { + let current_bb = ctx.builder.get_insert_block().unwrap(); + let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e")); + ctx.builder.build_unconditional_branch(bb).unwrap(); + + ctx.builder.position_at_end(bb); + let left_ty = left_tys[i]; + let left_elem = ctx.build_in_bounds_gep_and_load( + lhs.into_pointer_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ); + let right_ty = right_tys[i]; + let right_elem = ctx.build_in_bounds_gep_and_load( + rhs.into_pointer_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ); + + gen_if_callback( + generator, + ctx, + |generator, ctx| { + gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty), left_elem), + &[*op], + &[(Some(right_ty), right_elem)], + ) + .transpose() + .unwrap() + .and_then(|v| v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)) + .map(BasicValueEnum::into_int_value) + }, + |_, ctx| { + ctx.builder.build_store(cmp_addr, llvm_i1.const_zero()).unwrap(); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + )?; + } + + // All elements checked - Finally check length of tuples + gen_if_callback( + generator, + ctx, + |_, ctx| { + let cmp = ctx.builder.build_load(cmp_addr, "").map(BasicValueEnum::into_int_value).unwrap(); + + let llvm_cmpop = match op { + Cmpop::Eq => IntPredicate::EQ, + Cmpop::NotEq => IntPredicate::NE, + Cmpop::Lt => IntPredicate::ULT, + Cmpop::LtE => IntPredicate::ULE, + Cmpop::Gt => IntPredicate::UGT, + Cmpop::GtE => IntPredicate::UGE, + _ => unreachable!(), + }; + + Ok(ctx.builder.build_int_compare( + llvm_cmpop, + cmp, + llvm_i1.const_all_ones(), + "", + ).unwrap()) + }, + |_, ctx| { + ctx.builder.build_store(cmp_addr, llvm_i1.const_zero()).unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + )?; + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + ctx.builder.position_at_end(post_foreach_cmp); + ctx.builder.build_load(cmp_addr, "").map(BasicValueEnum::into_int_value).unwrap() } else { return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) };