From cea7cade51fe6701447ecae8696d4ea363a97425 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 12 Jul 2024 18:18:28 +0800 Subject: [PATCH] core: add np_argmax/np_argmin functions --- nac3core/src/toplevel/builtins.rs | 50 +++++++++++++++++-------------- nac3core/src/toplevel/helper.rs | 4 +++ 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index f2097dd..7ad22b9 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -510,7 +510,10 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), - PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim), + PrimDef::FunNpArgmin + | PrimDef::FunNpArgmax + | PrimDef::FunNpMin + | PrimDef::FunNpMax => self.build_np_max_min_function(prim), PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { self.build_np_minimum_maximum_function(prim) @@ -1555,39 +1558,42 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build the functions `np_min()` and `np_max()`. - fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]); + /// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()` + /// Calls `call_numpy_max_min` with the function name + fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax]); - let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); - let var_map = self - .num_or_ndarray_var_map - .clone() - .into_iter() - .chain(once((ret_ty.id, ret_ty.ty))) - .collect::>(); + let (var_map, ret_ty) = match prim { + PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => { + (self.num_or_ndarray_var_map.clone(), self.primitives.int64) + }, + PrimDef::FunNpMax | PrimDef::FunNpMin => { + let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); + let var_map = self + .num_or_ndarray_var_map + .clone() + .into_iter() + .chain(once((ret_ty.id, ret_ty.ty))) + .collect::>(); + (var_map, ret_ty.ty) + }, + _ => unreachable!() + }; create_fn_by_codegen( self.unifier, &var_map, prim.name(), - ret_ty.ty, - &[(self.float_or_ndarray_ty.ty, "a")], + ret_ty, + &[(self.num_or_ndarray_ty.ty, "a")], Box::new(move |ctx, _, fun, args, generator| { let a_ty = fun.0.args[0].ty; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; - - let func = match prim { - PrimDef::FunNpMin => builtin_fns::call_numpy_min, - PrimDef::FunNpMax => builtin_fns::call_numpy_max, - _ => unreachable!(), - }; - - Ok(Some(func(generator, ctx, (a_ty, a))?)) + + Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), &prim.name())?)) }), ) } - /// Build the functions `np_minimum()` and `np_maximum()`. fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 1d0c291..5560f41 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -62,9 +62,11 @@ pub enum PrimDef { FunMin, FunNpMin, FunNpMinimum, + FunNpArgmin, FunMax, FunNpMax, FunNpMaximum, + FunNpArgmax, FunAbs, FunNpIsNan, FunNpIsInf, @@ -216,9 +218,11 @@ impl PrimDef { PrimDef::FunMin => fun("min", None), PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMinimum => fun("np_minimum", None), + PrimDef::FunNpArgmin => fun("np_argmin", None), PrimDef::FunMax => fun("max", None), PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMaximum => fun("np_maximum", None), + PrimDef::FunNpArgmax => fun("np_argmax", None), PrimDef::FunAbs => fun("abs", None), PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsInf => fun("np_isinf", None),