From fe6f259d483618619af70bcff9ac201daf62cc5e Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 1 Aug 2024 18:16:55 +0800 Subject: [PATCH 1/4] kernel: add linalg functions --- src/Cargo.lock | 219 +++++-------- src/libksupport/src/kernel/api.rs | 38 +-- src/libksupport/src/kernel/linalg.rs | 440 +++++++++++++++++++++++++++ src/libksupport/src/kernel/mod.rs | 1 + 4 files changed, 531 insertions(+), 167 deletions(-) create mode 100644 src/libksupport/src/kernel/linalg.rs diff --git a/src/Cargo.lock b/src/Cargo.lock index c14d523..0fd99fa 100644 --- a/src/Cargo.lock +++ b/src/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "arrayvec" version = "0.7.4" @@ -246,10 +255,10 @@ dependencies = [ "libsupport_zynq", "log", "log_buffer", + "nalgebra", "nb 0.1.3", "unwind", "vcell", - "nalgebra", "void", ] @@ -383,6 +392,19 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c75de51135344a4f8ed3cfe2720dc27736f7711989703a0b43aadf3753c55577" +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be" +dependencies = [ + "approx", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + [[package]] name = "nb" version = "0.1.3" @@ -398,6 +420,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "546c37ac5d9e56f55e73b677106873d9d9f5190605e41a856503623648488cae" +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -409,6 +440,26 @@ dependencies = [ "syn", ] +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -416,8 +467,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -524,6 +582,18 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4f410fedcf71af0345d7607d246e7ad15faaadd49d240ee3b24e5dc21a820ac" +[[package]] +name = "simba" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", +] + [[package]] name = "smoltcp" version = "0.7.5" @@ -556,6 +626,12 @@ dependencies = [ "log", ] +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" version = "1.0.5" @@ -572,147 +648,6 @@ dependencies = [ "libc", ] -[[package]] -name = "nalgebra" -version = "0.32.6" -source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be" -dependencies = [ - "approx", - "matrixmultiply", - "nalgebra-macros", - "num-complex", - "num-rational", - "num-traits", - "simba", - "typenum", -] - -[[package]] -name = "approx" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" -dependencies = [ - "num-traits", -] - -[[package]] -name = "matrixmultiply" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" -dependencies = [ - "autocfg", - "rawpointer", -] - -[[package]] -name = "nalgebra-macros" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "num-complex" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "simba" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" -dependencies = [ - "approx", - "num-complex", - "num-traits", - "paste", - "wide", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - -[[package]] -name = "num-bigint" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d047c1062aa51e256408c560894e5251f08925980e53cf1aa5bd00eec6512" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - -[[package]] -name = "wide" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd89cf484471f953ee84f07c0dff0ea20e9ddf976f03cabdf5dda48b221f22e7" -features = ["no_std"] -dependencies = [ - "bytemuck", - "safe_arch", -] - -[[package]] -name = "bytemuck" -version = "1.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" - -[[package]] -name = "safe_arch" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529" -dependencies = [ - "bytemuck", -] - [[package]] name = "vcell" version = "0.1.3" diff --git a/src/libksupport/src/kernel/api.rs b/src/libksupport/src/kernel/api.rs index 4ffb7b5..106024f 100644 --- a/src/libksupport/src/kernel/api.rs +++ b/src/libksupport/src/kernel/api.rs @@ -1,16 +1,15 @@ use alloc::vec; -use core::{ffi::VaList, ptr, slice, str}; +use core::{ffi::VaList, ptr, str}; use libc::{c_char, c_int, size_t}; use libm; use log::{info, warn}; -use nalgebra::{linalg, DMatrix}; #[cfg(has_drtio)] use super::subkernel; use super::{cache, core1::rtio_get_destination_status, - dma, + dma, linalg, rpc::{rpc_recv, rpc_send, rpc_send_async}}; use crate::{eh_artiq, i2c, rtio}; @@ -39,26 +38,6 @@ unsafe extern "C" fn rtio_log(fmt: *const c_char, mut args: ...) { rtio::write_log(buf.as_slice()); } -unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 { - let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - let mut inverted_matrix = DMatrix::::zeros(dim0, dim1); - - if linalg::try_invert_to(matrix, &mut inverted_matrix) { - data_slice.copy_from_slice(inverted_matrix.transpose().as_slice()); - 1 - } else { - 0 - } -} - -unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 { - let data_slice = slice::from_raw_parts_mut(data, dim0 * dim1); - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - - linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]) -} - macro_rules! api { ($i:ident) => ({ extern { static $i: u8; } @@ -342,8 +321,17 @@ pub fn resolve(required: &[u8]) -> Option { }, // linalg - api!(linalg_try_invert_to = linalg_try_invert_to), - api!(linalg_wilkinson_shift = linalg_wilkinson_shift), + api!(np_linalg_cholesky = linalg::np_linalg_cholesky), + api!(np_linalg_qr = linalg::np_linalg_qr), + api!(np_linalg_svd = linalg::np_linalg_svd), + api!(np_linalg_inv = linalg::np_linalg_inv), + api!(np_linalg_pinv = linalg::np_linalg_pinv), + api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power), + api!(np_linalg_det = linalg::np_linalg_det), + api!(sp_linalg_lu = linalg::sp_linalg_lu), + api!(sp_linalg_schur = linalg::sp_linalg_schur), + api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg), + ]; api.iter() .find(|&&(exported, _)| exported.as_bytes() == required) diff --git a/src/libksupport/src/kernel/linalg.rs b/src/libksupport/src/kernel/linalg.rs new file mode 100644 index 0000000..b5e5769 --- /dev/null +++ b/src/libksupport/src/kernel/linalg.rs @@ -0,0 +1,440 @@ +// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions +// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary +// +// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order +// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known) +// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer + +use alloc::vec::Vec; +use core::slice; + +use nalgebra::DMatrix; + +use crate::artiq_raise; + +pub struct InputMatrix { + pub ndims: usize, + pub dims: *const usize, + pub data: *mut f64, +} + +impl InputMatrix { + fn get_dims(&mut self) -> Vec { + let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) }; + dims.to_vec() + } +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + if dim1[0] != dim1[1] { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix1.cholesky(); + match result { + Some(res) => { + out_slice.copy_from_slice(res.unpack().transpose().as_slice()); + } + None => { + artiq_raise!("LinAlgError", "Matrix is not positive definite"); + } + }; +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputMatrix, out_r: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_q = out_q.as_mut().unwrap(); + let out_r = out_r.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let outq_dim = (*out_q).get_dims(); + let outr_dim = (*out_r).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) }; + let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) }; + + // Refer to https://github.com/dimforge/nalgebra/issues/735 + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + + let res = matrix1.qr(); + let (q, r) = res.unpack(); + + // Uses different algo need to match numpy + out_q_slice.copy_from_slice(q.transpose().as_slice()); + out_r_slice.copy_from_slice(r.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_svd( + mat1: *mut InputMatrix, + outu: *mut InputMatrix, + outs: *mut InputMatrix, + outvh: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let outu = outu.as_mut().unwrap(); + let outs = outs.as_mut().unwrap(); + let outvh = outvh.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let outu_dim = (*outu).get_dims(); + let outs_dim = (*outs).get_dims(); + let outvh_dim = (*outvh).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) }; + let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) }; + let out_vh_slice = unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix.svd(true, true); + out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice()); + out_s_slice.copy_from_slice(result.singular_values.as_slice()); + out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix.is_invertible() { + artiq_raise!("LinAlgError", "no inverse for Singular Matrix"); + } + let inv = matrix.try_inverse().unwrap(); + out_slice.copy_from_slice(inv.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + let dim1 = (*mat1).get_dims(); + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let svd = matrix.svd(true, true); + let inv = svd.pseudo_inverse(1e-15); + + match inv { + Ok(m) => { + out_slice.copy_from_slice(m.transpose().as_slice()); + } + Err(_) => { + artiq_raise!("LinAlgError", "SVD computation does not converge"); + } + } +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) }; + let power = power[0]; + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let mut abs_power = power; + if abs_power < 0.0 { + abs_power = abs_power * -1.0; + } + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix1.is_square() { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + let mut result = matrix1.pow(abs_power as u32); + + if power < 0.0 { + if !matrix1.is_invertible() { + artiq_raise!("LinAlgError", "no inverse for Singular Matrix"); + } + result = result.try_inverse().unwrap(); + } + out_slice.copy_from_slice(result.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + let dim1 = (*mat1).get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + if !matrix.is_square() { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + out_slice[0] = matrix.determinant(); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputMatrix, out_u: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_l = out_l.as_mut().unwrap(); + let out_u = out_u.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let outl_dim = (*out_l).get_dims(); + let outu_dim = (*out_u).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (_, l, u) = matrix.lu().unpack(); + + out_l_slice.copy_from_slice(l.transpose().as_slice()); + out_u_slice.copy_from_slice(u.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut InputMatrix, out_z: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out_t = out_t.as_mut().unwrap(); + let out_z = out_z.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + + let out_t_dim = (*out_t).get_dims(); + let out_z_dim = (*out_z).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) }; + let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (z, t) = matrix.schur().unpack(); + + out_t_slice.copy_from_slice(t.transpose().as_slice()); + out_z_slice.copy_from_slice(z.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_hessenberg( + mat1: *mut InputMatrix, + out_h: *mut InputMatrix, + out_q: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let out_h = out_h.as_mut().unwrap(); + let out_q = out_q.as_mut().unwrap(); + + if mat1.ndims != 2 { + artiq_raise!( + "ValueError", + "expected 2D Vector Input, but received {1}D input)", + 0, + mat1.ndims as i64, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + + if dim1[0] != dim1[1] { + artiq_raise!( + "ValueError", + "last 2 dimensions of the array must be square: {1} != {2}", + 0, + dim1[0] as i64, + dim1[1] as i64 + ); + } + + let out_h_dim = (*out_h).get_dims(); + let out_q_dim = (*out_q).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) }; + let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let (q, h) = matrix.hessenberg().unpack(); + + out_h_slice.copy_from_slice(h.transpose().as_slice()); + out_q_slice.copy_from_slice(q.transpose().as_slice()); +} diff --git a/src/libksupport/src/kernel/mod.rs b/src/libksupport/src/kernel/mod.rs index 8a8d48d..cb4255f 100644 --- a/src/libksupport/src/kernel/mod.rs +++ b/src/libksupport/src/kernel/mod.rs @@ -13,6 +13,7 @@ mod dma; mod rpc; pub use dma::DmaRecorder; mod cache; +mod linalg; #[cfg(has_drtio)] mod subkernel; From fee30033ec5fe874d510109b67e79bb098650be5 Mon Sep 17 00:00:00 2001 From: Simon Renblad Date: Fri, 1 Mar 2024 16:46:51 +0800 Subject: [PATCH 2/4] comms: run idle kernel on start-up --- src/runtime/src/comms.rs | 66 +++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/src/runtime/src/comms.rs b/src/runtime/src/comms.rs index e6abfb4..979b5d9 100644 --- a/src/runtime/src/comms.rs +++ b/src/runtime/src/comms.rs @@ -696,6 +696,31 @@ async fn handle_connection( } } +async fn load_and_run_idle_kernel( + buffer: &Vec, + control: &Rc>, + up_destinations: &Rc>, + aux_mutex: &Rc>, + routing_table: &drtio_routing::RoutingTable, + timer: GlobalTimer, +) { + info!("Loading idle kernel"); + let res = handle_flash_kernel(buffer, control, up_destinations, aux_mutex, routing_table, timer).await; + match res { + #[cfg(has_drtio)] + Err(Error::DestinationDown) => { + let mut countdown = timer.countdown(); + delay(&mut countdown, Milliseconds(500)).await; + } + Err(_) => warn!("error loading idle kernel"), + _ => (), + } + info!("Running idle kernel"); + let _ = handle_run_kernel(None, control, up_destinations, aux_mutex, routing_table, timer) + .await.map_err(|_| warn!("error running idle kernel")); + info!("Idle kernel terminated"); +} + pub fn main(timer: GlobalTimer, cfg: Config) { let net_addresses = net_settings::get_addresses(&cfg); info!("network addresses: {}", net_addresses); @@ -786,8 +811,30 @@ pub fn main(timer: GlobalTimer, cfg: Config) { mgmt::start(cfg); task::spawn(async move { - let connection = Rc::new(Semaphore::new(1, 1)); + let connection = Rc::new(Semaphore::new(0, 1)); let terminate = Rc::new(Semaphore::new(0, 1)); + { + let control = control.clone(); + let idle_kernel = idle_kernel.clone(); + let connection = connection.clone(); + let terminate = terminate.clone(); + let up_destinations = up_destinations.clone(); + let aux_mutex = aux_mutex.clone(); + let routing_table = drtio_routing_table.clone(); + task::spawn(async move { + let routing_table = routing_table.borrow(); + select_biased! { + _ = (async { + if let Some(buffer) = &*idle_kernel { + load_and_run_idle_kernel(&buffer, &control, &up_destinations, &aux_mutex, &routing_table, timer).await; + } + }).fuse() => (), + _ = terminate.async_wait().fuse() => () + } + connection.signal(); + }); + } + loop { let mut stream = TcpStream::accept(1381, 0x10_000, 0x10_000).await.unwrap(); @@ -815,22 +862,7 @@ pub fn main(timer: GlobalTimer, cfg: Config) { .await .map_err(|e| warn!("connection terminated: {}", e)); if let Some(buffer) = &*idle_kernel { - info!("Loading idle kernel"); - let res = handle_flash_kernel(&buffer, &control, &up_destinations, &aux_mutex, &routing_table, timer) - .await; - match res { - #[cfg(has_drtio)] - Err(Error::DestinationDown) => { - let mut countdown = timer.countdown(); - delay(&mut countdown, Milliseconds(500)).await; - } - Err(_) => warn!("error loading idle kernel"), - _ => (), - } - info!("Running idle kernel"); - let _ = handle_run_kernel(None, &control, &up_destinations, &aux_mutex, &routing_table, timer) - .await.map_err(|_| warn!("error running idle kernel")); - info!("Idle kernel terminated"); + load_and_run_idle_kernel(&buffer, &control, &up_destinations, &aux_mutex, &routing_table, timer).await; } }).fuse() => (), _ = terminate.async_wait().fuse() => () From fad1db9796bb194f72009dbd2ab852b386e10853 Mon Sep 17 00:00:00 2001 From: Simon Renblad Date: Wed, 31 Jul 2024 18:00:01 +0800 Subject: [PATCH 3/4] comms: remove idle kernel DRTIO error case --- src/runtime/src/comms.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/runtime/src/comms.rs b/src/runtime/src/comms.rs index 979b5d9..e7c8a14 100644 --- a/src/runtime/src/comms.rs +++ b/src/runtime/src/comms.rs @@ -707,11 +707,6 @@ async fn load_and_run_idle_kernel( info!("Loading idle kernel"); let res = handle_flash_kernel(buffer, control, up_destinations, aux_mutex, routing_table, timer).await; match res { - #[cfg(has_drtio)] - Err(Error::DestinationDown) => { - let mut countdown = timer.countdown(); - delay(&mut countdown, Milliseconds(500)).await; - } Err(_) => warn!("error loading idle kernel"), _ => (), } From 78d6b7ddcffa405f9a62366bfa984c658cd79f79 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 5 Aug 2024 19:37:55 +0800 Subject: [PATCH 4/4] cargo fmt --- src/runtime/src/comms.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/runtime/src/comms.rs b/src/runtime/src/comms.rs index e7c8a14..0e21c98 100644 --- a/src/runtime/src/comms.rs +++ b/src/runtime/src/comms.rs @@ -705,14 +705,15 @@ async fn load_and_run_idle_kernel( timer: GlobalTimer, ) { info!("Loading idle kernel"); - let res = handle_flash_kernel(buffer, control, up_destinations, aux_mutex, routing_table, timer).await; + let res = handle_flash_kernel(buffer, control, up_destinations, aux_mutex, routing_table, timer).await; match res { Err(_) => warn!("error loading idle kernel"), _ => (), } info!("Running idle kernel"); let _ = handle_run_kernel(None, control, up_destinations, aux_mutex, routing_table, timer) - .await.map_err(|_| warn!("error running idle kernel")); + .await + .map_err(|_| warn!("error running idle kernel")); info!("Idle kernel terminated"); }