From 9788be129443d617aa68a57b2e993011ffa98faa Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 26 Jul 2024 14:47:58 +0800 Subject: [PATCH] kernel: add linalg functions --- src/Cargo.lock | 219 ++++++--------- src/libksupport/src/eh_artiq.rs | 4 +- src/libksupport/src/kernel/api.rs | 38 +-- src/libksupport/src/kernel/linalg.rs | 381 +++++++++++++++++++++++++++ src/libksupport/src/kernel/mod.rs | 1 + 5 files changed, 475 insertions(+), 168 deletions(-) create mode 100644 src/libksupport/src/kernel/linalg.rs diff --git a/src/Cargo.lock b/src/Cargo.lock index c14d523..0fd99fa 100644 --- a/src/Cargo.lock +++ b/src/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "arrayvec" version = "0.7.4" @@ -246,10 +255,10 @@ dependencies = [ "libsupport_zynq", "log", "log_buffer", + "nalgebra", "nb 0.1.3", "unwind", "vcell", - "nalgebra", "void", ] @@ -383,6 +392,19 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c75de51135344a4f8ed3cfe2720dc27736f7711989703a0b43aadf3753c55577" +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be" +dependencies = [ + "approx", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + [[package]] name = "nb" version = "0.1.3" @@ -398,6 +420,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "546c37ac5d9e56f55e73b677106873d9d9f5190605e41a856503623648488cae" +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -409,6 +440,26 @@ dependencies = [ "syn", ] +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -416,8 +467,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -524,6 +582,18 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4f410fedcf71af0345d7607d246e7ad15faaadd49d240ee3b24e5dc21a820ac" +[[package]] +name = "simba" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", +] + [[package]] name = "smoltcp" version = "0.7.5" @@ -556,6 +626,12 @@ dependencies = [ "log", ] +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" version = "1.0.5" @@ -572,147 +648,6 @@ dependencies = [ "libc", ] -[[package]] -name = "nalgebra" -version = "0.32.6" -source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be" -dependencies = [ - "approx", - "matrixmultiply", - "nalgebra-macros", - "num-complex", - "num-rational", - "num-traits", - "simba", - "typenum", -] - -[[package]] -name = "approx" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" -dependencies = [ - "num-traits", -] - -[[package]] -name = "matrixmultiply" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" -dependencies = [ - "autocfg", - "rawpointer", -] - -[[package]] -name = "nalgebra-macros" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "num-complex" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "simba" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" -dependencies = [ - "approx", - "num-complex", - "num-traits", - "paste", - "wide", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - -[[package]] -name = "num-bigint" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d047c1062aa51e256408c560894e5251f08925980e53cf1aa5bd00eec6512" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - -[[package]] -name = "wide" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd89cf484471f953ee84f07c0dff0ea20e9ddf976f03cabdf5dda48b221f22e7" -features = ["no_std"] -dependencies = [ - "bytemuck", - "safe_arch", -] - -[[package]] -name = "bytemuck" -version = "1.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" - -[[package]] -name = "safe_arch" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529" -dependencies = [ - "bytemuck", -] - [[package]] name = "vcell" version = "0.1.3" diff --git a/src/libksupport/src/eh_artiq.rs b/src/libksupport/src/eh_artiq.rs index 6f159ac..69e8e80 100644 --- a/src/libksupport/src/eh_artiq.rs +++ b/src/libksupport/src/eh_artiq.rs @@ -422,7 +422,7 @@ extern "C" fn stop_fn( } // Must be kept in sync with preallocate_runtime_exception_names() in artiq/language/embedding_map.py -static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [ +static EXCEPTION_ID_LOOKUP: [(&str, u32); 14] = [ ("RuntimeError", 0), ("RTIOUnderflow", 1), ("RTIOOverflow", 2), @@ -435,6 +435,8 @@ static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [ ("IndexError", 9), ("UnwrapNoneError", 10), ("SubkernelError", 11), + ("ValueError", 12), + ("LinAlgError", 13), ]; pub fn get_exception_id(name: &str) -> u32 { diff --git a/src/libksupport/src/kernel/api.rs b/src/libksupport/src/kernel/api.rs index 4ffb7b5..6cf6d1f 100644 --- a/src/libksupport/src/kernel/api.rs +++ b/src/libksupport/src/kernel/api.rs @@ -1,16 +1,15 @@ use alloc::vec; -use core::{ffi::VaList, ptr, slice, str}; +use core::{ffi::VaList, ptr, str}; use libc::{c_char, c_int, size_t}; use libm; use log::{info, warn}; -use nalgebra::{linalg, DMatrix}; #[cfg(has_drtio)] use super::subkernel; use super::{cache, core1::rtio_get_destination_status, - dma, + dma, linalg, rpc::{rpc_recv, rpc_send, rpc_send_async}}; use crate::{eh_artiq, i2c, rtio}; @@ -39,26 +38,6 @@ unsafe extern "C" fn rtio_log(fmt: *const c_char, mut args: ...) { rtio::write_log(buf.as_slice()); } -unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 { - let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - let mut inverted_matrix = DMatrix::::zeros(dim0, dim1); - - if linalg::try_invert_to(matrix, &mut inverted_matrix) { - data_slice.copy_from_slice(inverted_matrix.transpose().as_slice()); - 1 - } else { - 0 - } -} - -unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 { - let data_slice = slice::from_raw_parts_mut(data, dim0 * dim1); - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - - linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]) -} - macro_rules! api { ($i:ident) => ({ extern { static $i: u8; } @@ -342,8 +321,17 @@ pub fn resolve(required: &[u8]) -> Option { }, // linalg - api!(linalg_try_invert_to = linalg_try_invert_to), - api!(linalg_wilkinson_shift = linalg_wilkinson_shift), + api!(np_dot = linalg::np_dot), + api!(np_linalg_matmul = linalg::np_linalg_matmul), + api!(np_linalg_cholesky = linalg::np_linalg_cholesky), + api!(np_linalg_qr = linalg::np_linalg_qr), + api!(np_linalg_svd = linalg::np_linalg_svd), + api!(np_linalg_inv = linalg::np_linalg_inv), + api!(np_linalg_pinv = linalg::np_linalg_pinv), + api!(sp_linalg_lu = linalg::sp_linalg_lu), + api!(sp_linalg_schur = linalg::sp_linalg_schur), + api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg), + ]; api.iter() .find(|&&(exported, _)| exported.as_bytes() == required) diff --git a/src/libksupport/src/kernel/linalg.rs b/src/libksupport/src/kernel/linalg.rs new file mode 100644 index 0000000..2529008 --- /dev/null +++ b/src/libksupport/src/kernel/linalg.rs @@ -0,0 +1,381 @@ +// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions +// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary +// +// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order +// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known) +// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer + +use alloc::vec::Vec; +use core::slice; + +use nalgebra::DMatrix; + +use crate::artiq_raise; + +pub struct InputMatrix { + pub ndims: usize, + pub dims: *const usize, + pub data: *mut f64, +} + +impl InputMatrix { + fn get_dims(&mut self) -> Vec { + let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) }; + dims.to_vec() + } +} + +/// # Safety +/// +/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + + if !(mat1.ndims == 1 && mat2.ndims == 1) { + let err_msg = format!( + "expected 1D Vector Input, but received {}-D and {}-D input", + mat1.ndims, mat2.ndims + ); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let dim2 = (*mat2).get_dims(); + + if dim1[0] != dim2[0] { + let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]); + artiq_raise!("ValueError", err_msg); + } + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) }; + let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1); + let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2); + + matrix1.dot(&matrix2) +} + +/// # Safety +/// +/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_matmul(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if !(mat1.ndims == 2 && mat2.ndims == 2) { + let err_msg = format!( + "expected 2D Vector Input, but received {}-D and {}-D input", + mat1.ndims, mat2.ndims + ); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let dim2 = (*mat2).get_dims(); + + if dim1[1] != dim2[0] { + let err_msg = format!( + "shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)", + dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0] + ); + artiq_raise!("ValueError", err_msg); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2); + let mut result = DMatrix::::zeros(outdim[0], outdim[1]); + + matrix1.mul_to(&matrix2, &mut result); + out_slice.copy_from_slice(result.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + if dim1[0] != dim1[1] { + let err_msg = format!( + "last 2 dimensions of the array must be square: {0} != {1}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix1.cholesky(); + match result { + Some(res) => { + out_slice.copy_from_slice(res.unpack().transpose().as_slice()); + } + None => { + artiq_raise!("LinAlgError", "Matrix is not positive definite"); + } + }; +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputMatrix, out_r: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_q = out_q.as_mut().unwrap(); + let out_r = out_r.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let outq_dim = (*out_q).get_dims(); + let outr_dim = (*out_r).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) }; + let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) }; + + // Refer to https://github.com/dimforge/nalgebra/issues/735 + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + + let res = matrix1.qr(); + let (q, r) = res.unpack(); + + // Uses different algo need to match numpy + out_q_slice.copy_from_slice(q.transpose().as_slice()); + out_r_slice.copy_from_slice(r.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_svd( + mat1: *mut InputMatrix, + outu: *mut InputMatrix, + outs: *mut InputMatrix, + outvh: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let outu = outu.as_mut().unwrap(); + let outs = outs.as_mut().unwrap(); + let outvh = outvh.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let outu_dim = (*outu).get_dims(); + let outs_dim = (*outs).get_dims(); + let outvh_dim = (*outvh).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) }; + let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) }; + let out_vh_slice = unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix.svd(true, true); + out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice()); + out_s_slice.copy_from_slice(result.singular_values.as_slice()); + out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + let err_msg = format!( + "last 2 dimensions of the array must be square: {0} != {1}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix.is_invertible() { + artiq_raise!("LinAlgError", "no inverse for Singular Matrix"); + } + let inv = matrix.try_inverse().unwrap(); + out_slice.copy_from_slice(inv.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + let dim1 = (*mat1).get_dims(); + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let svd = matrix.svd(true, true); + let inv = svd.pseudo_inverse(1e-15); + + match inv { + Ok(m) => { + out_slice.copy_from_slice(m.transpose().as_slice()); + } + Err(e) => { + artiq_raise!("LinAlgError", e); + } + } +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputMatrix, out_u: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_l = out_l.as_mut().unwrap(); + let out_u = out_u.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + let outl_dim = (*out_l).get_dims(); + let outu_dim = (*out_u).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (_, l, u) = matrix.lu().unpack(); + + out_l_slice.copy_from_slice(l.transpose().as_slice()); + out_u_slice.copy_from_slice(u.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut InputMatrix, out_z: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_t = out_t.as_mut().unwrap(); + let out_z = out_z.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + let err_msg = format!( + "last 2 dimensions of the array must be square: {0} != {1}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + + let out_t_dim = (*out_t).get_dims(); + let out_z_dim = (*out_z).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) }; + let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (z, t) = matrix.schur().unpack(); + + out_t_slice.copy_from_slice(t.transpose().as_slice()); + out_z_slice.copy_from_slice(z.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_hessenberg( + mat1: *mut InputMatrix, + out_h: *mut InputMatrix, + out_q: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let out_h = out_h.as_mut().unwrap(); + let out_q = out_q.as_mut().unwrap(); + + if mat1.ndims != 2 { + let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); + artiq_raise!("ValueError", err_msg); + } + + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + let err_msg = format!( + "last 2 dimensions of the array must be square: {} != {}", + dim1[0], dim1[1] + ); + artiq_raise!("LinAlgError", err_msg); + } + + let out_h_dim = (*out_h).get_dims(); + let out_q_dim = (*out_q).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) }; + let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (q, h) = matrix.hessenberg().unpack(); + + out_h_slice.copy_from_slice(h.transpose().as_slice()); + out_q_slice.copy_from_slice(q.transpose().as_slice()); +} diff --git a/src/libksupport/src/kernel/mod.rs b/src/libksupport/src/kernel/mod.rs index b235cdd..2e511a2 100644 --- a/src/libksupport/src/kernel/mod.rs +++ b/src/libksupport/src/kernel/mod.rs @@ -13,6 +13,7 @@ mod dma; mod rpc; pub use dma::DmaRecorder; mod cache; +mod linalg; #[cfg(has_drtio)] mod subkernel;