From 9c8278121768de16fa608c26e1a5e5b90be3ca24 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Wed, 31 Jul 2024 18:00:00 +0800 Subject: [PATCH] kernel/api: add linalg.det and linalg.matrix_power functions --- src/libksupport/src/eh_artiq.rs | 4 +- src/libksupport/src/kernel/api.rs | 13 +-- src/libksupport/src/kernel/linalg.rs | 118 ++++++++++++++++++--------- 3 files changed, 89 insertions(+), 46 deletions(-) diff --git a/src/libksupport/src/eh_artiq.rs b/src/libksupport/src/eh_artiq.rs index 6f159ac..0ca73d7 100644 --- a/src/libksupport/src/eh_artiq.rs +++ b/src/libksupport/src/eh_artiq.rs @@ -467,5 +467,7 @@ macro_rules! artiq_raise { $crate::eh_artiq::raise(&exn) } }}; - ($name:expr, $message:expr) => {{ artiq_raise!($name, $message, 0, 0, 0) }}; + ($name:expr, $message:expr) => {{ + artiq_raise!($name, $message, 0, 0, 0) + }}; } diff --git a/src/libksupport/src/kernel/api.rs b/src/libksupport/src/kernel/api.rs index 6cf6d1f..b92f6a5 100644 --- a/src/libksupport/src/kernel/api.rs +++ b/src/libksupport/src/kernel/api.rs @@ -7,10 +7,12 @@ use log::{info, warn}; #[cfg(has_drtio)] use super::subkernel; -use super::{cache, - core1::rtio_get_destination_status, - dma, linalg, - rpc::{rpc_recv, rpc_send, rpc_send_async}}; +use super::{ + cache, + core1::rtio_get_destination_status, + dma, linalg, + rpc::{rpc_recv, rpc_send, rpc_send_async}, +}; use crate::{eh_artiq, i2c, rtio}; extern "C" { @@ -321,13 +323,14 @@ pub fn resolve(required: &[u8]) -> Option { }, // linalg - api!(np_dot = linalg::np_dot), api!(np_linalg_matmul = linalg::np_linalg_matmul), api!(np_linalg_cholesky = linalg::np_linalg_cholesky), api!(np_linalg_qr = linalg::np_linalg_qr), api!(np_linalg_svd = linalg::np_linalg_svd), api!(np_linalg_inv = linalg::np_linalg_inv), api!(np_linalg_pinv = linalg::np_linalg_pinv), + api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power), + api!(np_linalg_det = linalg::np_linalg_det), api!(sp_linalg_lu = linalg::sp_linalg_lu), api!(sp_linalg_schur = linalg::sp_linalg_schur), api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg), diff --git a/src/libksupport/src/kernel/linalg.rs b/src/libksupport/src/kernel/linalg.rs index 2529008..abef0a8 100644 --- a/src/libksupport/src/kernel/linalg.rs +++ b/src/libksupport/src/kernel/linalg.rs @@ -25,38 +25,6 @@ impl InputMatrix { } } -/// # Safety -/// -/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order -#[no_mangle] -pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 { - let mat1 = mat1.as_mut().unwrap(); - let mat2 = mat2.as_mut().unwrap(); - - if !(mat1.ndims == 1 && mat2.ndims == 1) { - let err_msg = format!( - "expected 1D Vector Input, but received {}-D and {}-D input", - mat1.ndims, mat2.ndims - ); - artiq_raise!("ValueError", err_msg); - } - - let dim1 = (*mat1).get_dims(); - let dim2 = (*mat2).get_dims(); - - if dim1[0] != dim2[0] { - let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]); - artiq_raise!("ValueError", err_msg); - } - let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) }; - let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) }; - - let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1); - let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2); - - matrix1.dot(&matrix2) -} - /// # Safety /// /// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order @@ -107,7 +75,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In let out = out.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } @@ -146,7 +114,7 @@ pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputM let out_r = out_r.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } @@ -185,7 +153,7 @@ pub unsafe extern "C" fn np_linalg_svd( let outvh = outvh.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } @@ -215,7 +183,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa let out = out.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } let dim1 = (*mat1).get_dims(); @@ -249,7 +217,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM let out = out.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } let dim1 = (*mat1).get_dims(); @@ -271,6 +239,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM } } +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) }; + let power = power[0]; + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let mut abs_power = power; + if abs_power < 0.0 { + abs_power = abs_power * -1.0; + } + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix1.is_square() { + let err_msg = format!( + "last 2 dimensions of the array must be square: {0} != {1}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + let mut result = matrix1.pow(abs_power as u32); + + if power < 0.0 { + if !matrix1.is_invertible() { + artiq_raise!("LinAlgError", "no inverse for Singular Matrix"); + } + result = result.try_inverse().unwrap(); + } + out_slice.copy_from_slice(result.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + let dim1 = (*mat1).get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix.is_square() { + let err_msg = format!( + "last 2 dimensions of the array must be square: {0} != {1}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + out_slice[0] = matrix.determinant(); +} + /// # Safety /// /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order @@ -281,7 +319,7 @@ pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputM let out_u = out_u.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } @@ -310,7 +348,7 @@ pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut Inp let out_z = out_z.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); } @@ -352,7 +390,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg( let out_q = out_q.as_mut().unwrap(); if mat1.ndims != 2 { - let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); artiq_raise!("ValueError", err_msg); }