forked from M-Labs/nac3
kernel: add linalg functions
This commit is contained in:
parent
e4d7ce114f
commit
fe6f259d48
|
@ -2,6 +2,15 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.4"
|
||||
|
@ -246,10 +255,10 @@ dependencies = [
|
|||
"libsupport_zynq",
|
||||
"log",
|
||||
"log_buffer",
|
||||
"nalgebra",
|
||||
"nb 0.1.3",
|
||||
"unwind",
|
||||
"vcell",
|
||||
"nalgebra",
|
||||
"void",
|
||||
]
|
||||
|
||||
|
@ -383,6 +392,19 @@ version = "0.7.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c75de51135344a4f8ed3cfe2720dc27736f7711989703a0b43aadf3753c55577"
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.32.6"
|
||||
source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nb"
|
||||
version = "0.1.3"
|
||||
|
@ -398,6 +420,15 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "546c37ac5d9e56f55e73b677106873d9d9f5190605e41a856503623648488cae"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.3.3"
|
||||
|
@ -409,6 +440,26 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.15"
|
||||
|
@ -416,8 +467,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
|
@ -524,6 +582,18 @@ version = "0.1.20"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4f410fedcf71af0345d7607d246e7ad15faaadd49d240ee3b24e5dc21a820ac"
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smoltcp"
|
||||
version = "0.7.5"
|
||||
|
@ -556,6 +626,12 @@ dependencies = [
|
|||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.5"
|
||||
|
@ -572,147 +648,6 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.32.6"
|
||||
source = "git+https://git.m-labs.hk/M-labs/nalgebra?rev=dd00f9b#dd00f9b46046e0b931d1b470166db02fd29591be"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra-macros"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e0d047c1062aa51e256408c560894e5251f08925980e53cf1aa5bd00eec6512"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd89cf484471f953ee84f07c0dff0ea20e9ddf976f03cabdf5dda48b221f22e7"
|
||||
features = ["no_std"]
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"safe_arch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.16.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e"
|
||||
|
||||
[[package]]
|
||||
name = "safe_arch"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vcell"
|
||||
version = "0.1.3"
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
use alloc::vec;
|
||||
use core::{ffi::VaList, ptr, slice, str};
|
||||
use core::{ffi::VaList, ptr, str};
|
||||
|
||||
use libc::{c_char, c_int, size_t};
|
||||
use libm;
|
||||
use log::{info, warn};
|
||||
use nalgebra::{linalg, DMatrix};
|
||||
|
||||
#[cfg(has_drtio)]
|
||||
use super::subkernel;
|
||||
use super::{cache,
|
||||
core1::rtio_get_destination_status,
|
||||
dma,
|
||||
dma, linalg,
|
||||
rpc::{rpc_recv, rpc_send, rpc_send_async}};
|
||||
use crate::{eh_artiq, i2c, rtio};
|
||||
|
||||
|
@ -39,26 +38,6 @@ unsafe extern "C" fn rtio_log(fmt: *const c_char, mut args: ...) {
|
|||
rtio::write_log(buf.as_slice());
|
||||
}
|
||||
|
||||
unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
||||
|
||||
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
||||
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
|
||||
let data_slice = slice::from_raw_parts_mut(data, dim0 * dim1);
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
|
||||
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
|
||||
}
|
||||
|
||||
macro_rules! api {
|
||||
($i:ident) => ({
|
||||
extern { static $i: u8; }
|
||||
|
@ -342,8 +321,17 @@ pub fn resolve(required: &[u8]) -> Option<u32> {
|
|||
},
|
||||
|
||||
// linalg
|
||||
api!(linalg_try_invert_to = linalg_try_invert_to),
|
||||
api!(linalg_wilkinson_shift = linalg_wilkinson_shift),
|
||||
api!(np_linalg_cholesky = linalg::np_linalg_cholesky),
|
||||
api!(np_linalg_qr = linalg::np_linalg_qr),
|
||||
api!(np_linalg_svd = linalg::np_linalg_svd),
|
||||
api!(np_linalg_inv = linalg::np_linalg_inv),
|
||||
api!(np_linalg_pinv = linalg::np_linalg_pinv),
|
||||
api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power),
|
||||
api!(np_linalg_det = linalg::np_linalg_det),
|
||||
api!(sp_linalg_lu = linalg::sp_linalg_lu),
|
||||
api!(sp_linalg_schur = linalg::sp_linalg_schur),
|
||||
api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg),
|
||||
|
||||
];
|
||||
api.iter()
|
||||
.find(|&&(exported, _)| exported.as_bytes() == required)
|
||||
|
|
|
@ -0,0 +1,440 @@
|
|||
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
|
||||
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
|
||||
//
|
||||
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
|
||||
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
|
||||
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::slice;
|
||||
|
||||
use nalgebra::DMatrix;
|
||||
|
||||
use crate::artiq_raise;
|
||||
|
||||
pub struct InputMatrix {
|
||||
pub ndims: usize,
|
||||
pub dims: *const usize,
|
||||
pub data: *mut f64,
|
||||
}
|
||||
|
||||
impl InputMatrix {
|
||||
fn get_dims(&mut self) -> Vec<usize> {
|
||||
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
|
||||
dims.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
if dim1[0] != dim1[1] {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let result = matrix1.cholesky();
|
||||
match result {
|
||||
Some(res) => {
|
||||
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
|
||||
}
|
||||
None => {
|
||||
artiq_raise!("LinAlgError", "Matrix is not positive definite");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputMatrix, out_r: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out_q = out_q.as_mut().unwrap();
|
||||
let out_r = out_r.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outq_dim = (*out_q).get_dims();
|
||||
let outr_dim = (*out_r).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
|
||||
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
|
||||
|
||||
// Refer to https://github.com/dimforge/nalgebra/issues/735
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
|
||||
let res = matrix1.qr();
|
||||
let (q, r) = res.unpack();
|
||||
|
||||
// Uses different algo need to match numpy
|
||||
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
||||
out_r_slice.copy_from_slice(r.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_svd(
|
||||
mat1: *mut InputMatrix,
|
||||
outu: *mut InputMatrix,
|
||||
outs: *mut InputMatrix,
|
||||
outvh: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let outu = outu.as_mut().unwrap();
|
||||
let outs = outs.as_mut().unwrap();
|
||||
let outvh = outvh.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outu_dim = (*outu).get_dims();
|
||||
let outs_dim = (*outs).get_dims();
|
||||
let outvh_dim = (*outvh).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
|
||||
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
|
||||
let out_vh_slice = unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let result = matrix.svd(true, true);
|
||||
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
|
||||
out_s_slice.copy_from_slice(result.singular_values.as_slice());
|
||||
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
|
||||
if dim1[0] != dim1[1] {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
if !matrix.is_invertible() {
|
||||
artiq_raise!("LinAlgError", "no inverse for Singular Matrix");
|
||||
}
|
||||
let inv = matrix.try_inverse().unwrap();
|
||||
out_slice.copy_from_slice(inv.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let svd = matrix.svd(true, true);
|
||||
let inv = svd.pseudo_inverse(1e-15);
|
||||
|
||||
match inv {
|
||||
Ok(m) => {
|
||||
out_slice.copy_from_slice(m.transpose().as_slice());
|
||||
}
|
||||
Err(_) => {
|
||||
artiq_raise!("LinAlgError", "SVD computation does not converge");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||
let power = power[0];
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let mut abs_power = power;
|
||||
if abs_power < 0.0 {
|
||||
abs_power = abs_power * -1.0;
|
||||
}
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
if !matrix1.is_square() {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
let mut result = matrix1.pow(abs_power as u32);
|
||||
|
||||
if power < 0.0 {
|
||||
if !matrix1.is_invertible() {
|
||||
artiq_raise!("LinAlgError", "no inverse for Singular Matrix");
|
||||
}
|
||||
result = result.try_inverse().unwrap();
|
||||
}
|
||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
if !matrix.is_square() {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
out_slice[0] = matrix.determinant();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputMatrix, out_u: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out_l = out_l.as_mut().unwrap();
|
||||
let out_u = out_u.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outl_dim = (*out_l).get_dims();
|
||||
let outu_dim = (*out_u).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
|
||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let (_, l, u) = matrix.lu().unpack();
|
||||
|
||||
out_l_slice.copy_from_slice(l.transpose().as_slice());
|
||||
out_u_slice.copy_from_slice(u.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut InputMatrix, out_z: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out_t = out_t.as_mut().unwrap();
|
||||
let out_z = out_z.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
|
||||
if dim1[0] != dim1[1] {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
|
||||
let out_t_dim = (*out_t).get_dims();
|
||||
let out_z_dim = (*out_z).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
|
||||
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let (z, t) = matrix.schur().unpack();
|
||||
|
||||
out_t_slice.copy_from_slice(t.transpose().as_slice());
|
||||
out_z_slice.copy_from_slice(z.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_hessenberg(
|
||||
mat1: *mut InputMatrix,
|
||||
out_h: *mut InputMatrix,
|
||||
out_q: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out_h = out_h.as_mut().unwrap();
|
||||
let out_q = out_q.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"expected 2D Vector Input, but received {1}D input)",
|
||||
0,
|
||||
mat1.ndims as i64,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
|
||||
if dim1[0] != dim1[1] {
|
||||
artiq_raise!(
|
||||
"ValueError",
|
||||
"last 2 dimensions of the array must be square: {1} != {2}",
|
||||
0,
|
||||
dim1[0] as i64,
|
||||
dim1[1] as i64
|
||||
);
|
||||
}
|
||||
|
||||
let out_h_dim = (*out_h).get_dims();
|
||||
let out_q_dim = (*out_q).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
|
||||
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let (q, h) = matrix.hessenberg().unpack();
|
||||
|
||||
out_h_slice.copy_from_slice(h.transpose().as_slice());
|
||||
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
||||
}
|
|
@ -13,6 +13,7 @@ mod dma;
|
|||
mod rpc;
|
||||
pub use dma::DmaRecorder;
|
||||
mod cache;
|
||||
mod linalg;
|
||||
#[cfg(has_drtio)]
|
||||
mod subkernel;
|
||||
|
||||
|
|
Loading…
Reference in New Issue