1
0
forked from M-Labs/nac3

core: Implement multi-operand __eq__ and __ne__ for lists

This commit is contained in:
David Mak 2024-07-02 20:10:39 +08:00
parent 66c205275f
commit c4052b6342
2 changed files with 176 additions and 2 deletions

View File

@ -14,7 +14,10 @@ use crate::{
call_memcpy_generic, call_memcpy_generic,
}, },
need_sret, numpy, need_sret, numpy,
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback, gen_raise, gen_var}, stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
},
CodeGenContext, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenTask, CodeGenerator,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
@ -36,7 +39,8 @@ use inkwell::{
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
Unaryop,
}; };
pub fn get_subst_key( pub fn get_subst_key(
@ -1833,6 +1837,175 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
_ => unreachable!(), _ => unreachable!(),
}; };
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
} else if [left_ty, right_ty]
.iter()
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let gen_list_cmpop = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<IntValue<'ctx>, String> {
let is_list1 =
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id());
let is_list2 =
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id());
let gen_bool_const = |ctx: &CodeGenContext<'ctx, '_>, val: bool| {
let llvm_i1 = ctx.ctx.bool_type();
match (op, val) {
(Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(),
(Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(),
(_, _) => unreachable!(),
}
};
if !(is_list1 && is_list2) {
return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false)));
}
let left_elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(left_ty)
{
*params.iter().next().unwrap().1
} else {
unreachable!()
};
let right_elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(right_ty)
{
*params.iter().next().unwrap().1
} else {
unreachable!()
};
if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) {
return Ok(generator.bool_to_i8(ctx, gen_bool_const(ctx, false)));
}
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
todo!("Only __eq__ and __ne__ is implemented for lists")
}
let left_val =
ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
let right_val =
ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None);
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::EQ,
left_val.load_size(ctx, None),
right_val.load_size(ctx, None),
"",
)
.unwrap())
},
|generator, ctx| {
let acc_addr = generator
.gen_var_alloc(ctx, ctx.ctx.bool_type().into(), None)
.unwrap();
ctx.builder
.build_store(acc_addr, ctx.ctx.bool_type().const_all_ones())
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(left_val.load_size(ctx, None), false),
|generator, ctx, hooks, i| {
let left = unsafe {
left_val.data().get_unchecked(ctx, generator, &i, None)
};
let right = unsafe {
right_val.data().get_unchecked(ctx, generator, &i, None)
};
let res = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(left_elem_ty), left),
&[Cmpop::Eq],
&[(Some(right_elem_ty), right)],
)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
.unwrap()
.into_int_value();
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::EQ,
res,
res.get_type().const_zero(),
"",
)
.unwrap())
},
|_, ctx| {
ctx.builder
.build_store(
acc_addr,
ctx.ctx.bool_type().const_zero(),
)
.unwrap();
ctx.builder
.build_unconditional_branch(hooks.exit_bb)
.unwrap();
Ok(())
},
|_, _| Ok(()),
)
.unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc = ctx
.builder
.build_load(acc_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let acc = if *op == Cmpop::NotEq {
gen_unaryop_expr_with_values(
generator,
ctx,
Unaryop::Not,
(&Some(ctx.primitives.bool), acc.into()),
)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
.into_int_value()
} else {
acc
};
Ok(Some(generator.bool_to_i8(ctx, acc)))
},
|generator, ctx| {
Ok(Some(generator.bool_to_i8(ctx, gen_bool_const(ctx, false))))
},
)?
.map(BasicValueEnum::into_int_value)
.unwrap())
};
gen_list_cmpop(generator, ctx)?
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) { } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) {
if ctx.registry.llvm_options.opt_level != OptimizationLevel::None { if ctx.registry.llvm_options.opt_level != OptimizationLevel::None {
ctx.make_assert( ctx.make_assert(

View File

@ -667,6 +667,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* list ======== */ /* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* ndarray ===== */ /* ndarray ===== */
let ndarray_usized_ndims_tvar = let ndarray_usized_ndims_tvar =