core: add np_linalg_det and np_linalg_matrix_power functions
This commit is contained in:
parent
54f883f0a5
commit
1c72698d02
@ -3,7 +3,9 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
|||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{
|
||||||
|
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
|
};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
@ -2196,6 +2198,104 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||||
|
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (x2_ty, x2) = x2;
|
||||||
|
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
unsafe {
|
||||||
|
n2_array.data().set_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
n2.as_basic_value_enum(),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||||
|
|
||||||
|
let outdim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let outdim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_det` linalg function
|
||||||
|
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let out = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||||
|
let res =
|
||||||
|
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
Ok(res)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `sp_linalg_schur` linalg function
|
/// Invokes the `sp_linalg_schur` linalg function
|
||||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -185,6 +185,8 @@ generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
|||||||
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
||||||
|
@ -2426,7 +2426,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
/// For matrix multiplication use `np_matmul`
|
/// For matrix multiplication use `np_matmul`
|
||||||
///
|
///
|
||||||
/// The input `NDArray` are flattened and treated as 1D
|
/// The input `NDArray` are flattened and treated as 1D
|
||||||
/// The operation is equivalent to np.dot(arr1.ravel(), arr2.ravel())
|
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
||||||
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -568,6 +568,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpLinalgSvd
|
| PrimDef::FunNpLinalgSvd
|
||||||
| PrimDef::FunNpLinalgInv
|
| PrimDef::FunNpLinalgInv
|
||||||
| PrimDef::FunNpLinalgPinv
|
| PrimDef::FunNpLinalgPinv
|
||||||
|
| PrimDef::FunNpLinalgMatrixPower
|
||||||
|
| PrimDef::FunNpLinalgDet
|
||||||
| PrimDef::FunSpLinalgLu
|
| PrimDef::FunSpLinalgLu
|
||||||
| PrimDef::FunSpLinalgSchur
|
| PrimDef::FunSpLinalgSchur
|
||||||
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
||||||
@ -1954,6 +1956,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
PrimDef::FunNpLinalgSvd,
|
PrimDef::FunNpLinalgSvd,
|
||||||
PrimDef::FunNpLinalgInv,
|
PrimDef::FunNpLinalgInv,
|
||||||
PrimDef::FunNpLinalgPinv,
|
PrimDef::FunNpLinalgPinv,
|
||||||
|
PrimDef::FunNpLinalgMatrixPower,
|
||||||
|
PrimDef::FunNpLinalgDet,
|
||||||
PrimDef::FunSpLinalgLu,
|
PrimDef::FunSpLinalgLu,
|
||||||
PrimDef::FunSpLinalgSchur,
|
PrimDef::FunSpLinalgSchur,
|
||||||
PrimDef::FunSpLinalgHessenberg,
|
PrimDef::FunSpLinalgHessenberg,
|
||||||
@ -2072,10 +2076,39 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
_ => {
|
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
|
||||||
println!("{:?}", prim.name());
|
self.unifier,
|
||||||
unreachable!()
|
&VarMap::new(),
|
||||||
}
|
prim.name(),
|
||||||
|
self.ndarray_float_2d,
|
||||||
|
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(x1_ty, x1_val),
|
||||||
|
(x2_ty, x2_val),
|
||||||
|
)?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.float,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,6 +110,8 @@ pub enum PrimDef {
|
|||||||
FunNpLinalgSvd,
|
FunNpLinalgSvd,
|
||||||
FunNpLinalgInv,
|
FunNpLinalgInv,
|
||||||
FunNpLinalgPinv,
|
FunNpLinalgPinv,
|
||||||
|
FunNpLinalgMatrixPower,
|
||||||
|
FunNpLinalgDet,
|
||||||
FunSpLinalgLu,
|
FunSpLinalgLu,
|
||||||
FunSpLinalgSchur,
|
FunSpLinalgSchur,
|
||||||
FunSpLinalgHessenberg,
|
FunSpLinalgHessenberg,
|
||||||
@ -295,6 +297,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||||
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
||||||
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
||||||
|
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
|
||||||
|
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
|
||||||
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
||||||
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
||||||
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
||||||
|
@ -237,6 +237,8 @@ def patch(module):
|
|||||||
module.np_linalg_svd = np.linalg.svd
|
module.np_linalg_svd = np.linalg.svd
|
||||||
module.np_linalg_inv = np.linalg.inv
|
module.np_linalg_inv = np.linalg.inv
|
||||||
module.np_linalg_pinv = np.linalg.pinv
|
module.np_linalg_pinv = np.linalg.pinv
|
||||||
|
module.np_linalg_matrix_power = np.linalg.matrix_power
|
||||||
|
module.np_linalg_det = np.linalg.det
|
||||||
|
|
||||||
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
||||||
module.sp_linalg_schur = sp.linalg.schur
|
module.sp_linalg_schur = sp.linalg.schur
|
||||||
|
@ -267,6 +267,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_matrix_power(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
mat2: *mut InputMatrix,
|
||||||
|
out: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let mat2 = mat2.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||||
|
let power = power[0];
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let abs_pow = power.abs();
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let mut result = matrix1.pow(abs_pow as u32);
|
||||||
|
|
||||||
|
if power < 0.0 {
|
||||||
|
if !result.is_invertible() {
|
||||||
|
report_error(
|
||||||
|
"LinAlgError",
|
||||||
|
"np_linalg_inv",
|
||||||
|
file!(),
|
||||||
|
line!(),
|
||||||
|
column!(),
|
||||||
|
"no inverse for Singular Matrix",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result = result.try_inverse().unwrap();
|
||||||
|
}
|
||||||
|
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix.is_square() {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
out_slice[0] = matrix.determinant();
|
||||||
|
}
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
@ -1518,6 +1518,20 @@ def test_ndarray_pinv():
|
|||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_2(y)
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_matrix_power():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_matrix_power(x, -9)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_det():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_det(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
def test_ndarray_schur():
|
def test_ndarray_schur():
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
t, z = sp_linalg_schur(x)
|
t, z = sp_linalg_schur(x)
|
||||||
@ -1751,6 +1765,8 @@ def run() -> int32:
|
|||||||
test_ndarray_svd()
|
test_ndarray_svd()
|
||||||
test_ndarray_linalg_inv()
|
test_ndarray_linalg_inv()
|
||||||
test_ndarray_pinv()
|
test_ndarray_pinv()
|
||||||
|
test_ndarray_matrix_power()
|
||||||
|
test_ndarray_det()
|
||||||
test_ndarray_lu()
|
test_ndarray_lu()
|
||||||
test_ndarray_schur()
|
test_ndarray_schur()
|
||||||
test_ndarray_hessenberg()
|
test_ndarray_hessenberg()
|
||||||
|
Loading…
Reference in New Issue
Block a user