Compare commits

...

2 Commits

13 changed files with 127 additions and 256 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

106
Cargo.lock generated
View File

@ -73,15 +73,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "3.0.0" version = "3.0.0"
@ -256,12 +247,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -536,12 +521,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
@ -552,14 +531,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -688,70 +659,17 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
] ]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.19.0"
@ -781,12 +699,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.6.5" version = "0.6.5"
@ -1158,18 +1070,6 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.6.0" version = "2.6.0"
@ -1330,12 +1230,6 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
version = "0.9.0" version = "0.9.0"

View File

@ -1867,55 +1867,6 @@ fn build_output_struct<'ctx>(
out_ptr out_ptr
} }
/// Invokes the `np_linalg_matmul` linalg function
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n2.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_cholesky` linalg function /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,

View File

@ -179,7 +179,6 @@ macro_rules! generate_linalg_extern_fn {
}; };
} }
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);

View File

@ -562,7 +562,6 @@ impl<'a> BuiltinBuilder<'a> {
} }
PrimDef::FunNpDot PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr | PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd | PrimDef::FunNpLinalgSvd
@ -1950,7 +1949,6 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::FunNpDot, PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky, PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr, PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd, PrimDef::FunNpLinalgSvd,
@ -1981,27 +1979,6 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
), ),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,

View File

@ -104,7 +104,6 @@ pub enum PrimDef {
// Linalg functions // Linalg functions
FunNpDot, FunNpDot,
FunNpLinalgMatmul,
FunNpLinalgCholesky, FunNpLinalgCholesky,
FunNpLinalgQr, FunNpLinalgQr,
FunNpLinalgSvd, FunNpLinalgSvd,
@ -291,7 +290,6 @@ impl PrimDef {
// Linalg functions // Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),

View File

@ -8,7 +8,6 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg = { path = "./demo/linalg" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -5,8 +5,8 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -231,7 +231,6 @@ def patch(module):
# Linalg functions # Linalg functions
module.np_dot = np.dot module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd module.np_linalg_svd = np.linalg.svd

114
nac3standalone/demo/linalg/Cargo.lock generated Normal file
View File

@ -0,0 +1,114 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -9,3 +9,5 @@ crate-type = ["staticlib"]
[dependencies] [dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0" cslice = "0.3.0"
[workspace]

View File

@ -34,51 +34,6 @@ impl InputMatrix {
} }
} }
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if !(mat1.ndims == 2 && mat2.ndims == 2) {
let err_msg = format!(
"expected 2D Vector Input, but received {}D and {}D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[1] != dim2[0] {
let err_msg = format!(
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order

View File

@ -42,14 +42,11 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
linalg=../../target/debug/deps/liblinalg-?*.a
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
linalg=../../target/release/deps/liblinalg-?*.a
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
linalg=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg-?*.a
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
@ -57,19 +54,17 @@ rm -f ./*.o ./*.bc demo
if [ -z "$i386" ]; then if [ -z "$i386" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
cd linalg && cargo build --release --target x86_64-unknown-linux-gnu -q && cd ..
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/target/x86_64-unknown-linux-gnu/release/liblinalg.a
else else
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations # Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}" $nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
# Compile linalg crate to provide functions compatible with i386 architecture cd linalg && cargo build --release --target i686-unknown-linux-gnu -q && cd ..
cd linalg && nix-shell -p rustup --command "RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -q --release --target=i686-unknown-linux-gnu" && cd .. clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/target/i686-unknown-linux-gnu/release/liblinalg.a
linalg=../../target/i686-unknown-linux-gnu/release/liblinalg.a
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
fi fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then

View File

@ -1474,18 +1474,6 @@ def test_ndarray_dot():
output_float64(z5) output_float64(z5)
output_bool(z6) output_bool(z6)
def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky(): def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x) y = np_linalg_cholesky(x)
@ -1501,7 +1489,7 @@ def test_ndarray_qr():
# QR Factorization is not unique and gives different results in numpy and nalgebra # QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(y, z) a = y @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_linalg_inv(): def test_ndarray_linalg_inv():
@ -1540,7 +1528,7 @@ def test_ndarray_schur():
# Schur Factorization is not unique and gives different results in scipy and nalgebra # Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_hessenberg(): def test_ndarray_hessenberg():
@ -1551,7 +1539,7 @@ def test_ndarray_hessenberg():
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra # Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q)) a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a) output_ndarray_float_2(a)
@ -1572,7 +1560,7 @@ def test_ndarray_svd():
# SVD Factorization is not unique and gives different results in numpy and nalgebra # SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(x, z) a = x @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
output_ndarray_float_1(y) output_ndarray_float_1(y)
@ -1759,7 +1747,6 @@ def run() -> int32:
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_linalg_matmul()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()
test_ndarray_svd() test_ndarray_svd()