From 2e01b77fc87d66d857cb1e3d1856dd060192669e Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 12 Jul 2024 18:18:54 +0800 Subject: [PATCH] core: refactor np_max/np_min functions --- nac3core/src/codegen/builtin_fns.rs | 155 ++++++++++------------------ 1 file changed, 53 insertions(+), 102 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index c87e27306..aefbac9de 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -661,90 +661,6 @@ pub fn call_min<'ctx>( } } -/// Invokes the `np_min` builtin function. -pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - a: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_min"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (a_ty, a) = a; - - Ok(match a { - BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(a_ty, *ty))); - - a - } - - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - n_sz_eqz, - "0:ValueError", - "zero-size array to reduction operation minimum which has no identity", - [None, None, None], - ctx.current_loc, - ); - } - - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - } - - gen_for_callback_incrementing( - generator, - ctx, - llvm_usize.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); - ctx.builder.build_store(accumulator_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - accumulator - } - - _ => unsupported_type(ctx, FN_NAME, &[a_ty]), - }) -} - /// Invokes the `np_minimum` builtin function. pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -877,19 +793,21 @@ pub fn call_max<'ctx>( } } -/// Invokes the `np_max` builtin function. -pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the np_max, np_min, np_argmax, np_argmin functions +/// * `fn_name`: Can be one of "np_argmin", "np_argmax", "np_max", "np_min" +pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, a: (Type, BasicValueEnum<'ctx>), + fn_name: &str, ) -> Result, String> { - const FN_NAME: &str = "np_max"; + debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); + let llvm_int64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let (a_ty, a) = a; - - Ok(match a { + Ok( match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { debug_assert!([ ctx.primitives.bool, @@ -901,14 +819,17 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( ] .iter() .any(|ty| ctx.unifier.unioned(a_ty, *ty))); - - a + + match fn_name { + "np_argmin" | "np_argmax" => llvm_int64.const_zero().into(), + "np_max" | "np_min" => a, + _ => unreachable!() + } } - BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -923,41 +844,71 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( generator, n_sz_eqz, "0:ValueError", - "zero-size array to reduction operation minimum which has no identity", + format!("zero-size array to reduction operation {}", fn_name).as_str(), [None, None, None], ctx.current_loc, ); } let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + unsafe { let identity = n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); ctx.builder.build_store(accumulator_addr, identity).unwrap(); + ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); } gen_for_callback_incrementing( generator, ctx, - llvm_usize.const_int(1, false), + llvm_int64.const_int(1, false), (n_sz, false), - |generator, ctx, _, idx| { + |generator, ctx, _, idx,| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); + let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + + let result = match fn_name { + "np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)), + "np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)), + _ => unreachable!() + }; + + let updated_idx = match (accumulator, result){ + (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { + ctx.builder.build_select( + ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(), + idx.into(), + cur_idx, + "").unwrap() + }, + (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { + ctx.builder.build_select( + ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(), + idx.into(), + cur_idx, + "").unwrap() + }, + _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), + }; + ctx.builder.build_store(res_idx, updated_idx).unwrap(); ctx.builder.build_store(accumulator_addr, result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), + llvm_int64.const_int(1, false), )?; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - accumulator + match fn_name { + "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), + "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + _ => unreachable!() + } } - _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + _ => unsupported_type(ctx, fn_name, &[a_ty]) }) }