forked from M-Labs/nac3
1
0
Fork 0

core: refactor np_max/np_min functions

This commit is contained in:
abdul124 2024-07-12 18:18:54 +08:00
parent cea7cade51
commit 2e01b77fc8
1 changed files with 53 additions and 102 deletions

View File

@ -661,90 +661,6 @@ pub fn call_min<'ctx>(
}
}
/// Invokes the `np_min` builtin function.
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_min";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
})
}
/// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
@ -877,18 +793,20 @@ pub fn call_max<'ctx>(
}
}
/// Invokes the `np_max` builtin function.
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
/// Invokes the np_max, np_min, np_argmax, np_argmin functions
/// * `fn_name`: Can be one of "np_argmin", "np_argmax", "np_max", "np_min"
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
fn_name: &str,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_max";
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok( match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
@ -902,9 +820,12 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a,
_ => unreachable!()
}
}
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
@ -923,41 +844,71 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
format!("zero-size array to reduction operation {}", fn_name).as_str(),
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
llvm_int64.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
|generator, ctx, _, idx,| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
let result = match fn_name {
"np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)),
"np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)),
_ => unreachable!()
};
let updated_idx = match (accumulator, result){
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
ctx.builder.build_select(
ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(),
idx.into(),
cur_idx,
"").unwrap()
},
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
ctx.builder.build_select(
ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(),
idx.into(),
cur_idx,
"").unwrap()
},
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
};
ctx.builder.build_store(res_idx, updated_idx).unwrap();
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
llvm_int64.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
match fn_name {
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => unreachable!()
}
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
_ => unsupported_type(ctx, fn_name, &[a_ty])
})
}