Compare commits

...

1 Commits

Author SHA1 Message Date
0664f3336b kernel: add np_transpose function 2024-07-29 18:34:12 +08:00
2 changed files with 23 additions and 0 deletions

View File

@ -321,6 +321,7 @@ pub fn resolve(required: &[u8]) -> Option<u32> {
},
// linalg
api!(np_transpose = linalg::np_transpose),
api!(np_dot = linalg::np_dot),
api!(np_linalg_matmul = linalg::np_linalg_matmul),
api!(np_linalg_cholesky = linalg::np_linalg_cholesky),

View File

@ -379,3 +379,25 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_transpose(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
artiq_raise!("ValueError", err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
out_slice.copy_from_slice(matrix.as_slice());
}