1
0
forked from M-Labs/nac3

core: remove np_linalg_matmul

This commit is contained in:
abdul124 2024-08-01 18:43:06 +08:00
parent bf709889c4
commit 63d2b49b09
7 changed files with 5 additions and 139 deletions

View File

@ -1867,55 +1867,6 @@ fn build_output_struct<'ctx>(
out_ptr out_ptr
} }
/// Invokes the `np_linalg_matmul` linalg function
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n2.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_cholesky` linalg function /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,

View File

@ -179,7 +179,6 @@ macro_rules! generate_linalg_extern_fn {
}; };
} }
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);

View File

@ -562,7 +562,6 @@ impl<'a> BuiltinBuilder<'a> {
} }
PrimDef::FunNpDot PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr | PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd | PrimDef::FunNpLinalgSvd
@ -1950,7 +1949,6 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::FunNpDot, PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky, PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr, PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd, PrimDef::FunNpLinalgSvd,
@ -1981,27 +1979,6 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
), ),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,

View File

@ -104,7 +104,6 @@ pub enum PrimDef {
// Linalg functions // Linalg functions
FunNpDot, FunNpDot,
FunNpLinalgMatmul,
FunNpLinalgCholesky, FunNpLinalgCholesky,
FunNpLinalgQr, FunNpLinalgQr,
FunNpLinalgSvd, FunNpLinalgSvd,
@ -291,7 +290,6 @@ impl PrimDef {
// Linalg functions // Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),

View File

@ -5,8 +5,8 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -231,7 +231,6 @@ def patch(module):
# Linalg functions # Linalg functions
module.np_dot = np.dot module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd module.np_linalg_svd = np.linalg.svd

View File

@ -34,51 +34,6 @@ impl InputMatrix {
} }
} }
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if !(mat1.ndims == 2 && mat2.ndims == 2) {
let err_msg = format!(
"expected 2D Vector Input, but received {}D and {}D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[1] != dim2[0] {
let err_msg = format!(
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order

View File

@ -1474,18 +1474,6 @@ def test_ndarray_dot():
output_float64(z5) output_float64(z5)
output_bool(z6) output_bool(z6)
def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky(): def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x) y = np_linalg_cholesky(x)
@ -1501,7 +1489,7 @@ def test_ndarray_qr():
# QR Factorization is not unique and gives different results in numpy and nalgebra # QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(y, z) a = y @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_linalg_inv(): def test_ndarray_linalg_inv():
@ -1540,7 +1528,7 @@ def test_ndarray_schur():
# Schur Factorization is not unique and gives different results in scipy and nalgebra # Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_hessenberg(): def test_ndarray_hessenberg():
@ -1551,7 +1539,7 @@ def test_ndarray_hessenberg():
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra # Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q)) a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a) output_ndarray_float_2(a)
@ -1572,7 +1560,7 @@ def test_ndarray_svd():
# SVD Factorization is not unique and gives different results in numpy and nalgebra # SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(x, z) a = x @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
output_ndarray_float_1(y) output_ndarray_float_1(y)
@ -1759,7 +1747,6 @@ def run() -> int32:
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_linalg_matmul()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()
test_ndarray_svd() test_ndarray_svd()