diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 435243a82..9f7398dec 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -14,7 +14,10 @@ use crate::{ call_memcpy_generic, }, need_sret, numpy, - stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback, gen_raise, gen_var}, + stmt::{ + gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, + gen_var, + }, CodeGenContext, CodeGenTask, CodeGenerator, }, symbol_resolver::{SymbolValue, ValueEnum}, @@ -36,7 +39,8 @@ use inkwell::{ }; use itertools::{chain, izip, Either, Itertools}; use nac3parser::ast::{ - self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, + self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, + Unaryop, }; pub fn get_subst_key( @@ -1833,6 +1837,175 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( _ => unreachable!(), }; ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() + } else if [left_ty, right_ty] + .iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) + { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let gen_list_cmpop = |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>| + -> Result, String> { + let is_list1 = + left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()); + let is_list2 = + right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()); + + let gen_bool_const = |ctx: &CodeGenContext<'ctx, '_>, val: bool| { + let llvm_i1 = ctx.ctx.bool_type(); + + match (op, val) { + (Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(), + (Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(), + (_, _) => unreachable!(), + } + }; + + if !(is_list1 && is_list2) { + return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false))); + } + + let left_elem_ty = if let TypeEnum::TObj { params, .. } = + &*ctx.unifier.get_ty_immutable(left_ty) + { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + let right_elem_ty = if let TypeEnum::TObj { params, .. } = + &*ctx.unifier.get_ty_immutable(right_ty) + { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + + if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) { + return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false))); + } + + if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { + todo!("Only __eq__ and __ne__ is implemented for lists") + } + + let left_val = + ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); + let right_val = + ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None); + + Ok(gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::EQ, + left_val.load_size(ctx, None), + right_val.load_size(ctx, None), + "", + ) + .unwrap()) + }, + |generator, ctx| { + let acc_addr = generator + .gen_var_alloc(ctx, ctx.ctx.bool_type().into(), None) + .unwrap(); + ctx.builder + .build_store(acc_addr, ctx.ctx.bool_type().const_all_ones()) + .unwrap(); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (left_val.load_size(ctx, None), false), + |generator, ctx, hooks, i| { + let left = unsafe { + left_val.data().get_unchecked(ctx, generator, &i, None) + }; + let right = unsafe { + right_val.data().get_unchecked(ctx, generator, &i, None) + }; + + let res = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_elem_ty), left), + &[Cmpop::Eq], + &[(Some(right_elem_ty), right)], + )? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.bool) + .unwrap() + .into_int_value(); + + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::EQ, + res, + res.get_type().const_zero(), + "", + ) + .unwrap()) + }, + |_, ctx| { + ctx.builder + .build_store( + acc_addr, + ctx.ctx.bool_type().const_zero(), + ) + .unwrap(); + ctx.builder + .build_unconditional_branch(hooks.exit_bb) + .unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + ) + .unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let acc = ctx + .builder + .build_load(acc_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let acc = if *op == Cmpop::NotEq { + gen_unaryop_expr_with_values( + generator, + ctx, + Unaryop::Not, + (&Some(ctx.primitives.bool), acc.into()), + )? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.bool)? + .into_int_value() + } else { + acc + }; + + Ok(Some(generator.bool_to_i8(ctx, acc))) + }, + |generator, ctx| { + Ok(Some(generator.bool_to_i8(ctx, gen_bool_const(ctx, false)))) + }, + )? + .map(BasicValueEnum::into_int_value) + .unwrap()) + }; + + gen_list_cmpop(generator, ctx)? } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) { if ctx.registry.llvm_options.opt_level != OptimizationLevel::None { ctx.make_assert( diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index cbf870775..7f6b2f359 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -667,6 +667,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* list ======== */ impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); + impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t)); /* ndarray ===== */ let ndarray_usized_ndims_tvar =