forked from M-Labs/artiq-zynq
kernel/api: add linalg.det and linalg.matrix_power functions
This commit is contained in:
parent
67ad4ea14a
commit
9c82781217
@ -467,5 +467,7 @@ macro_rules! artiq_raise {
|
|||||||
$crate::eh_artiq::raise(&exn)
|
$crate::eh_artiq::raise(&exn)
|
||||||
}
|
}
|
||||||
}};
|
}};
|
||||||
($name:expr, $message:expr) => {{ artiq_raise!($name, $message, 0, 0, 0) }};
|
($name:expr, $message:expr) => {{
|
||||||
|
artiq_raise!($name, $message, 0, 0, 0)
|
||||||
|
}};
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,12 @@ use log::{info, warn};
|
|||||||
|
|
||||||
#[cfg(has_drtio)]
|
#[cfg(has_drtio)]
|
||||||
use super::subkernel;
|
use super::subkernel;
|
||||||
use super::{cache,
|
use super::{
|
||||||
|
cache,
|
||||||
core1::rtio_get_destination_status,
|
core1::rtio_get_destination_status,
|
||||||
dma, linalg,
|
dma, linalg,
|
||||||
rpc::{rpc_recv, rpc_send, rpc_send_async}};
|
rpc::{rpc_recv, rpc_send, rpc_send_async},
|
||||||
|
};
|
||||||
use crate::{eh_artiq, i2c, rtio};
|
use crate::{eh_artiq, i2c, rtio};
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -321,13 +323,14 @@ pub fn resolve(required: &[u8]) -> Option<u32> {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// linalg
|
// linalg
|
||||||
api!(np_dot = linalg::np_dot),
|
|
||||||
api!(np_linalg_matmul = linalg::np_linalg_matmul),
|
api!(np_linalg_matmul = linalg::np_linalg_matmul),
|
||||||
api!(np_linalg_cholesky = linalg::np_linalg_cholesky),
|
api!(np_linalg_cholesky = linalg::np_linalg_cholesky),
|
||||||
api!(np_linalg_qr = linalg::np_linalg_qr),
|
api!(np_linalg_qr = linalg::np_linalg_qr),
|
||||||
api!(np_linalg_svd = linalg::np_linalg_svd),
|
api!(np_linalg_svd = linalg::np_linalg_svd),
|
||||||
api!(np_linalg_inv = linalg::np_linalg_inv),
|
api!(np_linalg_inv = linalg::np_linalg_inv),
|
||||||
api!(np_linalg_pinv = linalg::np_linalg_pinv),
|
api!(np_linalg_pinv = linalg::np_linalg_pinv),
|
||||||
|
api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power),
|
||||||
|
api!(np_linalg_det = linalg::np_linalg_det),
|
||||||
api!(sp_linalg_lu = linalg::sp_linalg_lu),
|
api!(sp_linalg_lu = linalg::sp_linalg_lu),
|
||||||
api!(sp_linalg_schur = linalg::sp_linalg_schur),
|
api!(sp_linalg_schur = linalg::sp_linalg_schur),
|
||||||
api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg),
|
api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg),
|
||||||
|
@ -25,38 +25,6 @@ impl InputMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
|
|
||||||
let mat1 = mat1.as_mut().unwrap();
|
|
||||||
let mat2 = mat2.as_mut().unwrap();
|
|
||||||
|
|
||||||
if !(mat1.ndims == 1 && mat2.ndims == 1) {
|
|
||||||
let err_msg = format!(
|
|
||||||
"expected 1D Vector Input, but received {}-D and {}-D input",
|
|
||||||
mat1.ndims, mat2.ndims
|
|
||||||
);
|
|
||||||
artiq_raise!("ValueError", err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim1 = (*mat1).get_dims();
|
|
||||||
let dim2 = (*mat2).get_dims();
|
|
||||||
|
|
||||||
if dim1[0] != dim2[0] {
|
|
||||||
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
|
|
||||||
artiq_raise!("ValueError", err_msg);
|
|
||||||
}
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
|
|
||||||
|
|
||||||
matrix1.dot(&matrix2)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
@ -107,7 +75,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +114,7 @@ pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputM
|
|||||||
let out_r = out_r.as_mut().unwrap();
|
let out_r = out_r.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,7 +153,7 @@ pub unsafe extern "C" fn np_linalg_svd(
|
|||||||
let outvh = outvh.as_mut().unwrap();
|
let outvh = outvh.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +183,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
let dim1 = (*mat1).get_dims();
|
let dim1 = (*mat1).get_dims();
|
||||||
@ -249,7 +217,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
let dim1 = (*mat1).get_dims();
|
let dim1 = (*mat1).get_dims();
|
||||||
@ -271,6 +239,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let mat2 = mat2.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
artiq_raise!("ValueError", err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||||
|
let power = power[0];
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let mut abs_power = power;
|
||||||
|
if abs_power < 0.0 {
|
||||||
|
abs_power = abs_power * -1.0;
|
||||||
|
}
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix1.is_square() {
|
||||||
|
let err_msg = format!(
|
||||||
|
"last 2 dimensions of the array must be square: {0} != {1}",
|
||||||
|
dim1[0], dim1[1]
|
||||||
|
);
|
||||||
|
artiq_raise!("LinAlgError", err_msg);
|
||||||
|
}
|
||||||
|
let mut result = matrix1.pow(abs_power as u32);
|
||||||
|
|
||||||
|
if power < 0.0 {
|
||||||
|
if !matrix1.is_invertible() {
|
||||||
|
artiq_raise!("LinAlgError", "no inverse for Singular Matrix");
|
||||||
|
}
|
||||||
|
result = result.try_inverse().unwrap();
|
||||||
|
}
|
||||||
|
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
artiq_raise!("ValueError", err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix.is_square() {
|
||||||
|
let err_msg = format!(
|
||||||
|
"last 2 dimensions of the array must be square: {0} != {1}",
|
||||||
|
dim1[0], dim1[1]
|
||||||
|
);
|
||||||
|
artiq_raise!("LinAlgError", err_msg);
|
||||||
|
}
|
||||||
|
out_slice[0] = matrix.determinant();
|
||||||
|
}
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
@ -281,7 +319,7 @@ pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputM
|
|||||||
let out_u = out_u.as_mut().unwrap();
|
let out_u = out_u.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -310,7 +348,7 @@ pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut Inp
|
|||||||
let out_z = out_z.as_mut().unwrap();
|
let out_z = out_z.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -352,7 +390,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
|
|||||||
let out_q = out_q.as_mut().unwrap();
|
let out_q = out_q.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
artiq_raise!("ValueError", err_msg);
|
artiq_raise!("ValueError", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user