From 9e0f636d2a9434a744ccf42cab7c3f9c22ce1006 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 16:17:52 +0800 Subject: [PATCH] core: categorize np_{transpose,reshape} as 'view functions' --- nac3core/src/toplevel/builtins.rs | 110 +++++++++++++------------- nac3core/src/toplevel/helper.rs | 12 ++- nac3standalone/demo/interpret_demo.py | 6 +- 3 files changed, 67 insertions(+), 61 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index c95c618..dadc236 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -518,6 +518,10 @@ impl<'a> BuiltinBuilder<'a> { self.build_ndarray_property_getter_function(prim) } + PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_ndarray_view_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -583,10 +587,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { - self.build_np_sp_ndarray_function(prim) - } - PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -1464,6 +1464,57 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build np/sp functions that take as input `NDArray` only + fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + + match prim { + PrimDef::FunNpTranspose => { + let ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.ndarray_num_ty], + Some("T".into()), + None, + ); + create_fn_by_codegen( + self.unifier, + &into_var_map([ndarray_ty]), + prim.name(), + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + }), + ) + } + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpReshape => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }), + ), + + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1951,57 +2002,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build np/sp functions that take as input `NDArray` only - fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); - - match prim { - PrimDef::FunNpTranspose => { - let ndarray_ty = self.unifier.get_fresh_var_with_range( - &[self.ndarray_num_ty], - Some("T".into()), - None, - ); - create_fn_by_codegen( - self.unifier, - &into_var_map([ndarray_ty]), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ) - } - - // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and - // the `param_ty` for `create_fn_by_codegen`. - // - // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking - // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], - // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), - - _ => unreachable!(), - } - } - /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 46e4555..946e608 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -59,6 +59,10 @@ pub enum PrimDef { FunNpShape, FunNpStrides, + // NumPy ndarray view functions + FunNpTranspose, + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -106,8 +110,6 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunNpTranspose, - FunNpReshape, // Linalg functions FunNpDot, @@ -250,6 +252,10 @@ impl PrimDef { PrimDef::FunNpShape => fun("np_shape", None), PrimDef::FunNpStrides => fun("np_strides", None), + // NumPy NDArray view functions + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), @@ -297,8 +303,6 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunNpTranspose => fun("np_transpose", None), - PrimDef::FunNpReshape => fun("np_reshape", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index ca17c3d..56c6126 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,10 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray view functions + module.np_transpose = np.transpose + module.np_reshape = np.reshape + # NumPy NDArray property getters module.np_size = np.size module.np_shape = np.shape @@ -223,8 +227,6 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter - module.np_transpose = np.transpose - module.np_reshape = np.reshape # SciPy Math functions module.sp_spec_erf = special.erf