Compare commits

..

1 Commits

Author SHA1 Message Date
0664f3336b kernel: add np_transpose function 2024-07-29 18:34:12 +08:00
3 changed files with 72 additions and 90 deletions

View File

@ -422,7 +422,7 @@ extern "C" fn stop_fn(
} }
// Must be kept in sync with preallocate_runtime_exception_names() in artiq/language/embedding_map.py // Must be kept in sync with preallocate_runtime_exception_names() in artiq/language/embedding_map.py
static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [ static EXCEPTION_ID_LOOKUP: [(&str, u32); 14] = [
("RuntimeError", 0), ("RuntimeError", 0),
("RTIOUnderflow", 1), ("RTIOUnderflow", 1),
("RTIOOverflow", 2), ("RTIOOverflow", 2),
@ -435,6 +435,8 @@ static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [
("IndexError", 9), ("IndexError", 9),
("UnwrapNoneError", 10), ("UnwrapNoneError", 10),
("SubkernelError", 11), ("SubkernelError", 11),
("ValueError", 12),
("LinAlgError", 13),
]; ];
pub fn get_exception_id(name: &str) -> u32 { pub fn get_exception_id(name: &str) -> u32 {
@ -467,7 +469,5 @@ macro_rules! artiq_raise {
$crate::eh_artiq::raise(&exn) $crate::eh_artiq::raise(&exn)
} }
}}; }};
($name:expr, $message:expr) => {{ ($name:expr, $message:expr) => {{ artiq_raise!($name, $message, 0, 0, 0) }};
artiq_raise!($name, $message, 0, 0, 0)
}};
} }

View File

@ -7,12 +7,10 @@ use log::{info, warn};
#[cfg(has_drtio)] #[cfg(has_drtio)]
use super::subkernel; use super::subkernel;
use super::{ use super::{cache,
cache, core1::rtio_get_destination_status,
core1::rtio_get_destination_status, dma, linalg,
dma, linalg, rpc::{rpc_recv, rpc_send, rpc_send_async}};
rpc::{rpc_recv, rpc_send, rpc_send_async},
};
use crate::{eh_artiq, i2c, rtio}; use crate::{eh_artiq, i2c, rtio};
extern "C" { extern "C" {
@ -323,14 +321,14 @@ pub fn resolve(required: &[u8]) -> Option<u32> {
}, },
// linalg // linalg
api!(np_transpose = linalg::np_transpose),
api!(np_dot = linalg::np_dot),
api!(np_linalg_matmul = linalg::np_linalg_matmul), api!(np_linalg_matmul = linalg::np_linalg_matmul),
api!(np_linalg_cholesky = linalg::np_linalg_cholesky), api!(np_linalg_cholesky = linalg::np_linalg_cholesky),
api!(np_linalg_qr = linalg::np_linalg_qr), api!(np_linalg_qr = linalg::np_linalg_qr),
api!(np_linalg_svd = linalg::np_linalg_svd), api!(np_linalg_svd = linalg::np_linalg_svd),
api!(np_linalg_inv = linalg::np_linalg_inv), api!(np_linalg_inv = linalg::np_linalg_inv),
api!(np_linalg_pinv = linalg::np_linalg_pinv), api!(np_linalg_pinv = linalg::np_linalg_pinv),
api!(np_linalg_matrix_power = linalg::np_linalg_matrix_power),
api!(np_linalg_det = linalg::np_linalg_det),
api!(sp_linalg_lu = linalg::sp_linalg_lu), api!(sp_linalg_lu = linalg::sp_linalg_lu),
api!(sp_linalg_schur = linalg::sp_linalg_schur), api!(sp_linalg_schur = linalg::sp_linalg_schur),
api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg), api!(sp_linalg_hessenberg = linalg::sp_linalg_hessenberg),

View File

@ -25,6 +25,38 @@ impl InputMatrix {
} }
} }
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
if !(mat1.ndims == 1 && mat2.ndims == 1) {
let err_msg = format!(
"expected 1D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[0] != dim2[0] {
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
artiq_raise!("ValueError", err_msg);
}
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety /// # Safety
/// ///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
@ -75,7 +107,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -114,7 +146,7 @@ pub unsafe extern "C" fn np_linalg_qr(mat1: *mut InputMatrix, out_q: *mut InputM
let out_r = out_r.as_mut().unwrap(); let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -153,7 +185,7 @@ pub unsafe extern "C" fn np_linalg_svd(
let outvh = outvh.as_mut().unwrap(); let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -183,7 +215,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -217,7 +249,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -239,76 +271,6 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
} }
} }
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(mat1: *mut InputMatrix, mat2: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let mut abs_power = power;
if abs_power < 0.0 {
abs_power = abs_power * -1.0;
}
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix1.is_square() {
let err_msg = format!(
"last 2 dimensions of the array must be square: {0} != {1}",
dim1[0], dim1[1]
);
artiq_raise!("LinAlgError", err_msg);
}
let mut result = matrix1.pow(abs_power as u32);
if power < 0.0 {
if !matrix1.is_invertible() {
artiq_raise!("LinAlgError", "no inverse for Singular Matrix");
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg = format!(
"last 2 dimensions of the array must be square: {0} != {1}",
dim1[0], dim1[1]
);
artiq_raise!("LinAlgError", err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
@ -319,7 +281,7 @@ pub unsafe extern "C" fn sp_linalg_lu(mat1: *mut InputMatrix, out_l: *mut InputM
let out_u = out_u.as_mut().unwrap(); let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -348,7 +310,7 @@ pub unsafe extern "C" fn sp_linalg_schur(mat1: *mut InputMatrix, out_t: *mut Inp
let out_z = out_z.as_mut().unwrap(); let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -390,7 +352,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
let out_q = out_q.as_mut().unwrap(); let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg); artiq_raise!("ValueError", err_msg);
} }
@ -417,3 +379,25 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
out_h_slice.copy_from_slice(h.transpose().as_slice()); out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice()); out_q_slice.copy_from_slice(q.transpose().as_slice());
} }
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_transpose(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
out_slice.copy_from_slice(matrix.as_slice());
}