core: remove np_linalg_matmul
This commit is contained in:
parent
f8d3a374e6
commit
2237137f1a
|
@ -1867,55 +1867,6 @@ fn build_output_struct<'ctx>(
|
||||||
out_ptr
|
out_ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_matmul` linalg function
|
|
||||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_matmul";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
|
||||||
|
|
||||||
let outdim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let outdim1 = unsafe {
|
|
||||||
n2.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the `np_linalg_cholesky` linalg function
|
/// Invokes the `np_linalg_cholesky` linalg function
|
||||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
|
@ -179,7 +179,6 @@ macro_rules! generate_linalg_extern_fn {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||||
|
|
|
@ -562,7 +562,6 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunNpDot
|
||||||
| PrimDef::FunNpLinalgMatmul
|
|
||||||
| PrimDef::FunNpLinalgCholesky
|
| PrimDef::FunNpLinalgCholesky
|
||||||
| PrimDef::FunNpLinalgQr
|
| PrimDef::FunNpLinalgQr
|
||||||
| PrimDef::FunNpLinalgSvd
|
| PrimDef::FunNpLinalgSvd
|
||||||
|
@ -1950,7 +1949,6 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
prim,
|
prim,
|
||||||
&[
|
&[
|
||||||
PrimDef::FunNpDot,
|
PrimDef::FunNpDot,
|
||||||
PrimDef::FunNpLinalgMatmul,
|
|
||||||
PrimDef::FunNpLinalgCholesky,
|
PrimDef::FunNpLinalgCholesky,
|
||||||
PrimDef::FunNpLinalgQr,
|
PrimDef::FunNpLinalgQr,
|
||||||
PrimDef::FunNpLinalgSvd,
|
PrimDef::FunNpLinalgSvd,
|
||||||
|
@ -1981,27 +1979,6 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
self.ndarray_float_2d,
|
|
||||||
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_linalg_matmul(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(x1_ty, x1_val),
|
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
|
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
|
|
|
@ -104,7 +104,6 @@ pub enum PrimDef {
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
FunNpDot,
|
FunNpDot,
|
||||||
FunNpLinalgMatmul,
|
|
||||||
FunNpLinalgCholesky,
|
FunNpLinalgCholesky,
|
||||||
FunNpLinalgQr,
|
FunNpLinalgQr,
|
||||||
FunNpLinalgSvd,
|
FunNpLinalgSvd,
|
||||||
|
@ -291,7 +290,6 @@ impl PrimDef {
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
|
||||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||||
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
||||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||||
|
|
|
@ -5,8 +5,8 @@ import importlib.util
|
||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy as sp
|
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
import scipy as sp
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from numpy import int32, int64, uint32, uint64
|
from numpy import int32, int64, uint32, uint64
|
||||||
|
@ -231,7 +231,6 @@ def patch(module):
|
||||||
|
|
||||||
# Linalg functions
|
# Linalg functions
|
||||||
module.np_dot = np.dot
|
module.np_dot = np.dot
|
||||||
module.np_linalg_matmul = np.matmul
|
|
||||||
module.np_linalg_cholesky = np.linalg.cholesky
|
module.np_linalg_cholesky = np.linalg.cholesky
|
||||||
module.np_linalg_qr = np.linalg.qr
|
module.np_linalg_qr = np.linalg.qr
|
||||||
module.np_linalg_svd = np.linalg.svd
|
module.np_linalg_svd = np.linalg.svd
|
||||||
|
|
|
@ -34,51 +34,6 @@ impl InputMatrix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_matmul(
|
|
||||||
mat1: *mut InputMatrix,
|
|
||||||
mat2: *mut InputMatrix,
|
|
||||||
out: *mut InputMatrix,
|
|
||||||
) {
|
|
||||||
let mat1 = mat1.as_mut().unwrap();
|
|
||||||
let mat2 = mat2.as_mut().unwrap();
|
|
||||||
let out = out.as_mut().unwrap();
|
|
||||||
|
|
||||||
if !(mat1.ndims == 2 && mat2.ndims == 2) {
|
|
||||||
let err_msg = format!(
|
|
||||||
"expected 2D Vector Input, but received {}D and {}D input",
|
|
||||||
mat1.ndims, mat2.ndims
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim1 = (*mat1).get_dims();
|
|
||||||
let dim2 = (*mat2).get_dims();
|
|
||||||
|
|
||||||
if dim1[1] != dim2[0] {
|
|
||||||
let err_msg = format!(
|
|
||||||
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
|
|
||||||
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let outdim = out.get_dims();
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
|
|
||||||
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
|
|
||||||
|
|
||||||
matrix1.mul_to(&matrix2, &mut result);
|
|
||||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
|
|
@ -1474,18 +1474,6 @@ def test_ndarray_dot():
|
||||||
output_float64(z5)
|
output_float64(z5)
|
||||||
output_bool(z6)
|
output_bool(z6)
|
||||||
|
|
||||||
def test_ndarray_linalg_matmul():
|
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
|
||||||
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
|
||||||
z = np_linalg_matmul(x, y)
|
|
||||||
|
|
||||||
m = np_argmax(z)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(y)
|
|
||||||
output_ndarray_float_2(z)
|
|
||||||
output_int64(m)
|
|
||||||
|
|
||||||
def test_ndarray_cholesky():
|
def test_ndarray_cholesky():
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||||
y = np_linalg_cholesky(x)
|
y = np_linalg_cholesky(x)
|
||||||
|
@ -1501,7 +1489,7 @@ def test_ndarray_qr():
|
||||||
|
|
||||||
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(y, z)
|
a = y @ z
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
def test_ndarray_linalg_inv():
|
def test_ndarray_linalg_inv():
|
||||||
|
@ -1540,7 +1528,7 @@ def test_ndarray_schur():
|
||||||
|
|
||||||
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
|
a = (z @ t) @ np_linalg_inv(z)
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
def test_ndarray_hessenberg():
|
def test_ndarray_hessenberg():
|
||||||
|
@ -1551,7 +1539,7 @@ def test_ndarray_hessenberg():
|
||||||
|
|
||||||
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q))
|
a = (q @ h) @ np_linalg_inv(q)
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1572,7 +1560,7 @@ def test_ndarray_svd():
|
||||||
|
|
||||||
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(x, z)
|
a = x @ z
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
output_ndarray_float_1(y)
|
output_ndarray_float_1(y)
|
||||||
|
|
||||||
|
@ -1759,7 +1747,6 @@ def run() -> int32:
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
test_ndarray_linalg_matmul()
|
|
||||||
test_ndarray_cholesky()
|
test_ndarray_cholesky()
|
||||||
test_ndarray_qr()
|
test_ndarray_qr()
|
||||||
test_ndarray_svd()
|
test_ndarray_svd()
|
||||||
|
|
Loading…
Reference in New Issue