From 1c72698d02789ad9d3a6ed2f4281483b14d60554 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Wed, 31 Jul 2024 18:02:54 +0800 Subject: [PATCH] core: add np_linalg_det and np_linalg_matrix_power functions --- nac3core/src/codegen/builtin_fns.rs | 102 +++++++++++++++++++++++++- nac3core/src/codegen/extern_fns.rs | 2 + nac3core/src/codegen/numpy.rs | 2 +- nac3core/src/toplevel/builtins.rs | 41 ++++++++++- nac3core/src/toplevel/helper.rs | 4 + nac3standalone/demo/interpret_demo.py | 2 + nac3standalone/demo/linalg/src/lib.rs | 70 ++++++++++++++++++ nac3standalone/demo/src/ndarray.py | 16 ++++ 8 files changed, 233 insertions(+), 6 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4bc7c913..ad15d61b 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -3,7 +3,9 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; -use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; +use crate::codegen::classes::{ + NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, +}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; @@ -2196,6 +2198,104 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( } } +/// Invokes the `np_linalg_matrix_power` linalg function +pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_matrix_power"; + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); + + let llvm_usize = generator.get_size_type(ctx.ctx); + if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + // Changing second parameter to a `NDArray` for uniformity in function call + let n2_array = numpy::create_ndarray_const_shape( + generator, + ctx, + elem_ty, + &[llvm_usize.const_int(1, false)], + ) + .unwrap(); + unsafe { + n2_array.data().set_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + n2.as_basic_value_enum(), + ); + }; + let n2_array = n2_array.as_base_value().as_basic_value_enum(); + + let outdim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let outdim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + } +} + +/// Invokes the `np_linalg_det` linalg function +pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_matrix_power"; + let (x1_ty, x1) = x1; + + let llvm_usize = generator.get_size_type(ctx.ctx); + if let BasicValueEnum::PointerValue(_) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unsupported_type(ctx, FN_NAME, &[x1_ty]); + }; + + // Changing second parameter to a `NDArray` for uniformity in function call + let out = numpy::create_ndarray_const_shape( + generator, + ctx, + elem_ty, + &[llvm_usize.const_int(1, false)], + ) + .unwrap(); + extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); + let res = + unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + Ok(res) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + /// Invokes the `sp_linalg_schur` linalg function pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 089a94f5..e181f57f 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -185,6 +185,8 @@ generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2); generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2); +generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3); +generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2); generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3); generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3); generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index f4299ff5..6ebffe80 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -2426,7 +2426,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( /// For matrix multiplication use `np_matmul` /// /// The input `NDArray` are flattened and treated as 1D -/// The operation is equivalent to np.dot(arr1.ravel(), arr2.ravel()) +/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())` pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 746ea052..49692baf 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -568,6 +568,8 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpLinalgSvd | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv + | PrimDef::FunNpLinalgMatrixPower + | PrimDef::FunNpLinalgDet | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur | PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim), @@ -1954,6 +1956,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpLinalgSvd, PrimDef::FunNpLinalgInv, PrimDef::FunNpLinalgPinv, + PrimDef::FunNpLinalgMatrixPower, + PrimDef::FunNpLinalgDet, PrimDef::FunSpLinalgLu, PrimDef::FunSpLinalgSchur, PrimDef::FunSpLinalgHessenberg, @@ -2072,10 +2076,39 @@ impl<'a> BuiltinBuilder<'a> { }), ) } - _ => { - println!("{:?}", prim.name()); - unreachable!() - } + PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_np_linalg_matrix_power( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }), + ), + PrimDef::FunNpLinalgDet => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.primitives.float, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?)) + }), + ), + _ => unreachable!(), } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index ae17e62c..be467182 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -110,6 +110,8 @@ pub enum PrimDef { FunNpLinalgSvd, FunNpLinalgInv, FunNpLinalgPinv, + FunNpLinalgMatrixPower, + FunNpLinalgDet, FunSpLinalgLu, FunSpLinalgSchur, FunSpLinalgHessenberg, @@ -295,6 +297,8 @@ impl PrimDef { PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None), PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None), + PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None), + PrimDef::FunNpLinalgDet => fun("np_linalg_det", None), PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None), PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None), PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index cb9693ff..15bf0853 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -237,6 +237,8 @@ def patch(module): module.np_linalg_svd = np.linalg.svd module.np_linalg_inv = np.linalg.inv module.np_linalg_pinv = np.linalg.pinv + module.np_linalg_matrix_power = np.linalg.matrix_power + module.np_linalg_det = np.linalg.det module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True) module.sp_linalg_schur = sp.linalg.schur diff --git a/nac3standalone/demo/linalg/src/lib.rs b/nac3standalone/demo/linalg/src/lib.rs index 76933c13..d4ee030f 100644 --- a/nac3standalone/demo/linalg/src/lib.rs +++ b/nac3standalone/demo/linalg/src/lib.rs @@ -267,6 +267,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM } } +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_matrix_power( + mat1: *mut InputMatrix, + mat2: *mut InputMatrix, + out: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims); + report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg); + } + + let dim1 = (*mat1).get_dims(); + let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) }; + let power = power[0]; + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let abs_pow = power.abs(); + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let mut result = matrix1.pow(abs_pow as u32); + + if power < 0.0 { + if !result.is_invertible() { + report_error( + "LinAlgError", + "np_linalg_inv", + file!(), + line!(), + column!(), + "no inverse for Singular Matrix", + ); + } + result = result.try_inverse().unwrap(); + } + out_slice.copy_from_slice(result.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); + report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg); + } + let dim1 = (*mat1).get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix.is_square() { + let err_msg = + format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]); + report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg); + } + out_slice[0] = matrix.determinant(); +} + /// # Safety /// /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index beb0c1d2..0c7e3233 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1518,6 +1518,20 @@ def test_ndarray_pinv(): output_ndarray_float_2(x) output_ndarray_float_2(y) +def test_ndarray_matrix_power(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + y = np_linalg_matrix_power(x, -9) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_det(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + y = np_linalg_det(x) + + output_ndarray_float_2(x) + output_float64(y) + def test_ndarray_schur(): x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) t, z = sp_linalg_schur(x) @@ -1751,6 +1765,8 @@ def run() -> int32: test_ndarray_svd() test_ndarray_linalg_inv() test_ndarray_pinv() + test_ndarray_matrix_power() + test_ndarray_det() test_ndarray_lu() test_ndarray_schur() test_ndarray_hessenberg()