kernel/api: add linalg.det and linalg.matrix_power functions

This commit is contained in:
abdul124 2024-07-31 18:00:00 +08:00
parent 67ad4ea14a
commit 9c82781217
3 changed files with 89 additions and 46 deletions

View File

@ -467,5 +467,7 @@ macro_rules! artiq_raise {
$crate::eh_artiq::raise(&exn) $crate::eh_artiq::raise(&exn)
} }
}}; }};
($name:expr, $message:expr) => {{ artiq_raise!($name, $message, 0, 0, 0) }}; ($name:expr, $message:expr) => {{
artiq_raise!($name, $message, 0, 0, 0)
}};
} }

View File

@ -7,10 +7,12 @@ use log::{info, warn};
#[cfg(has_drtio)] #[cfg(has_drtio)]
use super::subkernel; use super::subkernel;
use super::{cache, use super::{
cache,
core1::rtio_get_destination_status, core1::rtio_get_destination_status,
dma, linalg, dma, linalg,
rpc::{rpc_recv, rpc_send, rpc_send_async}}; rpc::{rpc_recv, rpc_send, rpc_send_async},
};
use crate::{eh_artiq, i2c, rtio}; use crate::{eh_artiq, i2c, rtio};
extern "C" { extern "C" {
@ -321,13 +323,14 @@ pub fn resolve(required: &[u8]) -> Option<u32> {
}, },
// linalg // linalg
api!(np_dot = linalg::np_dot),
api!(np_linalg_matmul = linalg::np_linalg_matmul), api!(np_linalg_matmul = linalg::np_linalg_matmul),
api!(np_linalg_cholesky = linalg::np_linalg_cholesky), api!(np_linalg_cholesky = linalg::np_linalg_cholesky),
api!(np_linalg_qr = linalg::np_linalg_qr), api!(np_linalg_qr = linalg::np_linalg_qr),
api!(np_linalg_svd = linalg::np_linalg_svd), api!(np_linalg_svd = linalg::np_linalg_svd),
api!(np_linalg_inv = linalg::np_linalg_inv), api!(np_linalg_inv = linalg::np_linalg_inv),
api!(np_linalg_pinv = linalg::np_linalg_pinv), api!(np_linalg_pinv = linalg::np_linalg_pinv),
api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power),
api!(np_linalg_det = linalg::np_linalg_det),
api!(sp_linalg_lu = linalg::sp_linalg_lu), api!(sp_linalg_lu = linalg::sp_linalg_lu),
api!(sp_linalg_schur = linalg::sp_linalg_schur), api!(sp_linalg_schur = linalg::sp_linalg_schur),
api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg), api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg),

View File

@ -25,38 +25,6 @@ impl InputMatrix {
} }
} }
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
if !(mat1.ndims == 1 && mat2.ndims == 1) {
let err_msg = format!(
"expected 1D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[0] != dim2[0] {
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
artiq_raise!("ValueError", err_msg);
}
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety /// # Safety
/// ///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
@ -107,7 +75,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -146,7 +114,7 @@ pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputM
let out_r = out_r.as_mut().unwrap(); let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -185,7 +153,7 @@ pub unsafe extern "C" fn np_linalg_svd(
let outvh = outvh.as_mut().unwrap(); let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -215,7 +183,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -249,7 +217,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -271,6 +239,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
} }
} }
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let mut abs_power = power;
if abs_power < 0.0 {
abs_power = abs_power * -1.0;
}
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix1.is_square() {
let err_msg = format!(
"last 2 dimensions of the array must be square: {0} != {1}",
dim1[0], dim1[1]
);
artiq_raise!("LinAlgError", err_msg);
}
let mut result = matrix1.pow(abs_power as u32);
if power < 0.0 {
if !matrix1.is_invertible() {
artiq_raise!("LinAlgError", "no inverse for Singular Matrix");
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg = format!(
"last 2 dimensions of the array must be square: {0} != {1}",
dim1[0], dim1[1]
);
artiq_raise!("LinAlgError", err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
@ -281,7 +319,7 @@ pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputM
let out_u = out_u.as_mut().unwrap(); let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -310,7 +348,7 @@ pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut Inp
let out_z = out_z.as_mut().unwrap(); let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -352,7 +390,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
let out_q = out_q.as_mut().unwrap(); let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }