core: Implement multi-operand __eq__ and __ne__ for lists
This commit is contained in:
parent
66c205275f
commit
c4052b6342
|
@ -14,7 +14,10 @@ use crate::{
|
|||
call_memcpy_generic,
|
||||
},
|
||||
need_sret, numpy,
|
||||
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback, gen_raise, gen_var},
|
||||
stmt::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
gen_var,
|
||||
},
|
||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
|
@ -36,7 +39,8 @@ use inkwell::{
|
|||
};
|
||||
use itertools::{chain, izip, Either, Itertools};
|
||||
use nac3parser::ast::{
|
||||
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
Unaryop,
|
||||
};
|
||||
|
||||
pub fn get_subst_key(
|
||||
|
@ -1833,6 +1837,175 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else if [left_ty, right_ty]
|
||||
.iter()
|
||||
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let gen_list_cmpop = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||
-> Result<IntValue<'ctx>, String> {
|
||||
let is_list1 =
|
||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||
let is_list2 =
|
||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||
|
||||
let gen_bool_const = |ctx: &CodeGenContext<'ctx, '_>, val: bool| {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
match (op, val) {
|
||||
(Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(),
|
||||
(Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(),
|
||||
(_, _) => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
if !(is_list1 && is_list2) {
|
||||
return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false)));
|
||||
}
|
||||
|
||||
let left_elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(left_ty)
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let right_elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(right_ty)
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) {
|
||||
return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false)));
|
||||
}
|
||||
|
||||
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
||||
todo!("Only __eq__ and __ne__ is implemented for lists")
|
||||
}
|
||||
|
||||
let left_val =
|
||||
ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
|
||||
let right_val =
|
||||
ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
Ok(gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
left_val.load_size(ctx, None),
|
||||
right_val.load_size(ctx, None),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|generator, ctx| {
|
||||
let acc_addr = generator
|
||||
.gen_var_alloc(ctx, ctx.ctx.bool_type().into(), None)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_store(acc_addr, ctx.ctx.bool_type().const_all_ones())
|
||||
.unwrap();
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_zero(),
|
||||
(left_val.load_size(ctx, None), false),
|
||||
|generator, ctx, hooks, i| {
|
||||
let left = unsafe {
|
||||
left_val.data().get_unchecked(ctx, generator, &i, None)
|
||||
};
|
||||
let right = unsafe {
|
||||
right_val.data().get_unchecked(ctx, generator, &i, None)
|
||||
};
|
||||
|
||||
let res = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(left_elem_ty), left),
|
||||
&[Cmpop::Eq],
|
||||
&[(Some(right_elem_ty), right)],
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
res,
|
||||
res.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
ctx.builder
|
||||
.build_store(
|
||||
acc_addr,
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_unconditional_branch(hooks.exit_bb)
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let acc = ctx
|
||||
.builder
|
||||
.build_load(acc_addr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let acc = if *op == Cmpop::NotEq {
|
||||
gen_unaryop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
Unaryop::Not,
|
||||
(&Some(ctx.primitives.bool), acc.into()),
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
|
||||
.into_int_value()
|
||||
} else {
|
||||
acc
|
||||
};
|
||||
|
||||
Ok(Some(generator.bool_to_i8(ctx, acc)))
|
||||
},
|
||||
|generator, ctx| {
|
||||
Ok(Some(generator.bool_to_i8(ctx, gen_bool_const(ctx, false))))
|
||||
},
|
||||
)?
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap())
|
||||
};
|
||||
|
||||
gen_list_cmpop(generator, ctx)?
|
||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) {
|
||||
if ctx.registry.llvm_options.opt_level != OptimizationLevel::None {
|
||||
ctx.make_assert(
|
||||
|
|
|
@ -667,6 +667,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
/* list ======== */
|
||||
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
|
||||
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar =
|
||||
|
|
Loading…
Reference in New Issue