forked from M-Labs/nac3
Compare commits
11 Commits
c71a567a51
...
2237137f1a
Author | SHA1 | Date | |
---|---|---|---|
2237137f1a | |||
f8d3a374e6 | |||
1c72698d02 | |||
54f883f0a5 | |||
4a6845dac6 | |||
00236f48bc | |||
a3e6bb2292 | |||
17171065b1 | |||
540b35ec84 | |||
4bb00c52e3 | |||
faf07527cb |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
/target
|
/target
|
||||||
|
/nac3standalone/demo/linalg/target
|
||||||
nix/windows/msys2
|
nix/windows/msys2
|
||||||
|
106
Cargo.lock
generated
106
Cargo.lock
generated
@ -73,15 +73,6 @@ dependencies = [
|
|||||||
"windows-sys",
|
"windows-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "approx"
|
|
||||||
version = "0.5.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
|
||||||
dependencies = [
|
|
||||||
"num-traits",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ascii-canvas"
|
name = "ascii-canvas"
|
||||||
version = "3.0.0"
|
version = "3.0.0"
|
||||||
@ -256,12 +247,6 @@ version = "0.2.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cslice"
|
|
||||||
version = "0.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-next"
|
name = "dirs-next"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
@ -536,12 +521,6 @@ dependencies = [
|
|||||||
"windows-targets",
|
"windows-targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "libm"
|
|
||||||
version = "0.2.8"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libredox"
|
name = "libredox"
|
||||||
version = "0.1.3"
|
version = "0.1.3"
|
||||||
@ -552,14 +531,6 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "linalg"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"cslice",
|
|
||||||
"nalgebra",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linked-hash-map"
|
name = "linked-hash-map"
|
||||||
version = "0.5.6"
|
version = "0.5.6"
|
||||||
@ -688,70 +659,17 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"inkwell",
|
"inkwell",
|
||||||
"linalg",
|
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "nalgebra"
|
|
||||||
version = "0.32.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
|
||||||
dependencies = [
|
|
||||||
"approx",
|
|
||||||
"num-complex",
|
|
||||||
"num-rational",
|
|
||||||
"num-traits",
|
|
||||||
"simba",
|
|
||||||
"typenum",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "new_debug_unreachable"
|
name = "new_debug_unreachable"
|
||||||
version = "1.0.6"
|
version = "1.0.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "num-complex"
|
|
||||||
version = "0.4.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
|
||||||
dependencies = [
|
|
||||||
"num-traits",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "num-integer"
|
|
||||||
version = "0.1.46"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
|
||||||
dependencies = [
|
|
||||||
"num-traits",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "num-rational"
|
|
||||||
version = "0.4.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
|
||||||
dependencies = [
|
|
||||||
"num-integer",
|
|
||||||
"num-traits",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "num-traits"
|
|
||||||
version = "0.2.19"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
|
||||||
dependencies = [
|
|
||||||
"autocfg",
|
|
||||||
"libm",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.19.0"
|
version = "1.19.0"
|
||||||
@ -781,12 +699,6 @@ dependencies = [
|
|||||||
"windows-targets",
|
"windows-targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "paste"
|
|
||||||
version = "1.0.15"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "petgraph"
|
name = "petgraph"
|
||||||
version = "0.6.5"
|
version = "0.6.5"
|
||||||
@ -1158,18 +1070,6 @@ dependencies = [
|
|||||||
"yaml-rust",
|
"yaml-rust",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "simba"
|
|
||||||
version = "0.8.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
|
||||||
dependencies = [
|
|
||||||
"approx",
|
|
||||||
"num-complex",
|
|
||||||
"num-traits",
|
|
||||||
"paste",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "similar"
|
name = "similar"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
@ -1330,12 +1230,6 @@ dependencies = [
|
|||||||
"crunchy",
|
"crunchy",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "typenum"
|
|
||||||
version = "1.17.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unic-char-property"
|
name = "unic-char-property"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
@ -161,9 +161,7 @@
|
|||||||
clippy
|
clippy
|
||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
rust-analyzer
|
|
||||||
];
|
];
|
||||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
|
||||||
};
|
};
|
||||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||||
name = "nac3-dev-shell-msys2";
|
name = "nac3-dev-shell-msys2";
|
||||||
|
@ -3,7 +3,9 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
|||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{
|
||||||
|
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
|
};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
@ -1865,83 +1867,6 @@ fn build_output_struct<'ctx>(
|
|||||||
out_ptr
|
out_ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_dot` linalg function
|
|
||||||
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_dot";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
|
|
||||||
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the `np_linalg_matmul` linalg function
|
|
||||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_matmul";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
|
||||||
|
|
||||||
let outdim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let outdim1 = unsafe {
|
|
||||||
n2.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the `np_linalg_cholesky` linalg function
|
/// Invokes the `np_linalg_cholesky` linalg function
|
||||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
@ -2224,6 +2149,104 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||||
|
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (x2_ty, x2) = x2;
|
||||||
|
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
unsafe {
|
||||||
|
n2_array.data().set_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
n2.as_basic_value_enum(),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||||
|
|
||||||
|
let outdim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let outdim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_det` linalg function
|
||||||
|
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let out = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||||
|
let res =
|
||||||
|
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
Ok(res)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `sp_linalg_schur` linalg function
|
/// Invokes the `sp_linalg_schur` linalg function
|
||||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -179,42 +179,13 @@ macro_rules! generate_linalg_extern_fn {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
|
|
||||||
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
||||||
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
||||||
|
|
||||||
/// Invokes the linalg `np_dot` function.
|
|
||||||
pub fn call_np_dot<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
mat1: BasicValueEnum<'ctx>,
|
|
||||||
mat2: BasicValueEnum<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> FloatValue<'ctx> {
|
|
||||||
const FN_NAME: &str = "np_dot";
|
|
||||||
|
|
||||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
||||||
let fn_type =
|
|
||||||
ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false);
|
|
||||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
||||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default())
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
@ -26,12 +26,15 @@ use crate::{
|
|||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use inkwell::{
|
||||||
|
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
||||||
|
values::BasicValue,
|
||||||
|
};
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
@ -2027,6 +2030,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.transpose`.
|
||||||
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
@ -2041,6 +2045,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// Dimensions are reversed in the transposed array
|
||||||
let out = create_ndarray_dyn_shape(
|
let out = create_ndarray_dyn_shape(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2066,30 +2071,16 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
(n_sz, false),
|
(n_sz, false),
|
||||||
|generator, ctx, _, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
// Calculate transposed idx
|
|
||||||
// 2, 3 => idx = row * num_col + col = 3 + 2 = 5
|
|
||||||
// 2, 3, 4 => idx = row * (num_col*num_z) + col * (num_z) + z => 12 + 8 + 3 | 18, [4, 2] 15, ,1,2
|
|
||||||
// num_z (col + num_col * row) + z
|
|
||||||
// 4D => 2, 3, 4, 5 = idx = num_w (num_z (col + num_col*row) + z) + w = 119
|
|
||||||
// [1, 1, 2]?
|
|
||||||
// 4 (1 + 3*1) + 2 = 6
|
|
||||||
// z = 2, col = 1, row = 0
|
|
||||||
// 0,1,
|
|
||||||
// 2, 3 => idx = row * num_col + col | 2, 3, 4 => idx = row * (num_col * num_z) + col * (num_z) + num_z
|
|
||||||
// ND => idx = 1 * (dim0 + dim1 + ... dimn) + dim[-1] * (dim0 + dim1 + ... + dimn-1) + ... + dim[1] * dim0
|
|
||||||
// 6 + 12 + 6 = 24 num_z * (row*num_col + col + 1) 4*6=24
|
|
||||||
// 2, 3, 4, 5 at idx 1 should go to
|
|
||||||
// 5, 4, 3, 2
|
|
||||||
|
|
||||||
// 18 => [2, 4] dim = 4
|
|
||||||
// 0 * 4 + 2 = 2
|
|
||||||
// 4 => [1, 1] dim = 3
|
|
||||||
// 2 * 3 + 1
|
|
||||||
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
||||||
ctx.builder.build_store(rem_idx, idx).unwrap();
|
ctx.builder.build_store(rem_idx, idx).unwrap();
|
||||||
|
|
||||||
|
// Incrementally calculate the new index in the transposed array
|
||||||
|
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
||||||
|
// The formula used for indexing is:
|
||||||
|
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2145,6 +2136,15 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
||||||
|
///
|
||||||
|
/// * `x1` - `NDArray` to reshape.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
|
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||||
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
@ -2162,44 +2162,19 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
// Check for -1 in the shapec
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
let ndim_ty = match shape {
|
|
||||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
|
||||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
|
||||||
{
|
|
||||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
|
||||||
shape_list
|
|
||||||
.data()
|
|
||||||
.get(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
.get_type()
|
|
||||||
}
|
|
||||||
BasicValueEnum::StructValue(shape_tuple) => ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(shape_tuple, 0, "")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value()
|
|
||||||
.get_type(),
|
|
||||||
BasicValueEnum::IntValue(shape_int) => shape_int.get_type(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let n_sz = ctx
|
|
||||||
.builder
|
|
||||||
.build_cast(inkwell::values::InstructionOpcode::Trunc, n_sz, ndim_ty, "")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, ndim_ty.into(), None)?;
|
|
||||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
|
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
||||||
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
let out = match shape {
|
let out = match shape {
|
||||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||||
{
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
||||||
|
|
||||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||||
|
// Check for -1 in dimensions
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2209,6 +2184,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|generator, ctx, _, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let ele =
|
let ele =
|
||||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
gen_if_else_expr_callback(
|
gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
@ -2219,7 +2195,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.build_int_compare(
|
.build_int_compare(
|
||||||
IntPredicate::SLT,
|
IntPredicate::SLT,
|
||||||
ele,
|
ele,
|
||||||
ndim_ty.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
@ -2253,6 +2229,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
)?;
|
)?;
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
|
// Generate the output shape by filling -1 with `rem`
|
||||||
create_ndarray_dyn_shape(
|
create_ndarray_dyn_shape(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2262,6 +2239,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|generator, ctx, shape_list, idx| {
|
|generator, ctx, shape_list, idx| {
|
||||||
let dim =
|
let dim =
|
||||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
Ok(gen_if_else_expr_callback(
|
Ok(gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2271,7 +2250,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.build_int_compare(
|
.build_int_compare(
|
||||||
IntPredicate::SLT,
|
IntPredicate::SLT,
|
||||||
dim,
|
dim,
|
||||||
ndim_ty.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
@ -2285,16 +2264,17 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
BasicValueEnum::StructValue(shape_tuple) => {
|
BasicValueEnum::StructValue(shape_tuple) => {
|
||||||
let ndims = shape_tuple.get_type().count_fields();
|
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
|
|
||||||
|
|
||||||
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
// Check for -1 in dims
|
||||||
for dim_i in 0..ndims {
|
for dim_i in 0..ndims {
|
||||||
let dim = ctx
|
let dim = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_extract_value(shape_tuple, dim_i, "")
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
gen_if_else_expr_callback(
|
gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
@ -2302,7 +2282,12 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|_, ctx| {
|
|_, ctx| {
|
||||||
Ok(ctx
|
Ok(ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
},
|
},
|
||||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||||
@ -2328,12 +2313,14 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
let mut shape = Vec::with_capacity(ndims as usize);
|
let mut shape = Vec::with_capacity(ndims as usize);
|
||||||
|
|
||||||
|
// Reconstruct shape filling negatives with rem
|
||||||
for dim_i in 0..ndims {
|
for dim_i in 0..ndims {
|
||||||
let dim = ctx
|
let dim = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_extract_value(shape_tuple, dim_i, "")
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
let dim = gen_if_else_expr_callback(
|
let dim = gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
@ -2341,7 +2328,12 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|_, ctx| {
|
|_, ctx| {
|
||||||
Ok(ctx
|
Ok(ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
},
|
},
|
||||||
|_, _| Ok(Some(rem)),
|
|_, _| Ok(Some(rem)),
|
||||||
@ -2354,6 +2346,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
}
|
}
|
||||||
BasicValueEnum::IntValue(shape_int) => {
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
|
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
let shape_int = gen_if_else_expr_callback(
|
let shape_int = gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -2363,13 +2356,15 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.build_int_compare(
|
.build_int_compare(
|
||||||
IntPredicate::SLT,
|
IntPredicate::SLT,
|
||||||
shape_int,
|
shape_int,
|
||||||
ndim_ty.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
.unwrap())
|
.unwrap())
|
||||||
},
|
},
|
||||||
|_, _| Ok(Some(n_sz)),
|
|_, _| Ok(Some(n_sz)),
|
||||||
|_, _| Ok(Some(shape_int)),
|
|_, ctx| {
|
||||||
|
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
||||||
|
},
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
@ -2379,6 +2374,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
// Only allow one dimension to be negative
|
||||||
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
@ -2391,17 +2387,13 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// The new shape must be compatible with the old shape
|
||||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
||||||
let out_sz = ctx
|
|
||||||
.builder
|
|
||||||
.build_cast(inkwell::values::InstructionOpcode::Trunc, out_sz, ndim_ty, "")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
"0:ValueError",
|
"0:ValueError",
|
||||||
"cannot reshape array of size {} into provided shape of size {}",
|
"cannot reshape array of size {0} into provided shape of size {1}",
|
||||||
[Some(n_sz), Some(out_sz), None],
|
[Some(n_sz), Some(out_sz), None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
@ -2428,3 +2420,102 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.dot`.
|
||||||
|
/// Calculate inner product of two vectors or literals
|
||||||
|
/// For matrix multiplication use `np_matmul`
|
||||||
|
///
|
||||||
|
/// The input `NDArray` are flattened and treated as 1D
|
||||||
|
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
||||||
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (_, x2) = x2;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match (x1, x2) {
|
||||||
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||||
|
|
||||||
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"shapes ({0}), ({1}) not aligned",
|
||||||
|
[Some(n1_sz), Some(n2_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let identity =
|
||||||
|
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n1_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
|
let product = match elem1 {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(e1, elem2.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
let acc_val = match acc_val {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(e1, product.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_add(e1, product.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
Ok(acc_val)
|
||||||
|
}
|
||||||
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
_ => unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -558,16 +558,17 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||||
self.build_np_sp_ndarray_1ary_function(prim)
|
self.build_np_sp_ndarray_function(prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunNpDot
|
||||||
| PrimDef::FunNpLinalgMatmul
|
|
||||||
| PrimDef::FunNpLinalgCholesky
|
| PrimDef::FunNpLinalgCholesky
|
||||||
| PrimDef::FunNpLinalgQr
|
| PrimDef::FunNpLinalgQr
|
||||||
| PrimDef::FunNpLinalgSvd
|
| PrimDef::FunNpLinalgSvd
|
||||||
| PrimDef::FunNpLinalgInv
|
| PrimDef::FunNpLinalgInv
|
||||||
| PrimDef::FunNpLinalgPinv
|
| PrimDef::FunNpLinalgPinv
|
||||||
|
| PrimDef::FunNpLinalgMatrixPower
|
||||||
|
| PrimDef::FunNpLinalgDet
|
||||||
| PrimDef::FunSpLinalgLu
|
| PrimDef::FunSpLinalgLu
|
||||||
| PrimDef::FunSpLinalgSchur
|
| PrimDef::FunSpLinalgSchur
|
||||||
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
||||||
@ -1889,62 +1890,53 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build 1-ary numpy/scipy functions that take in an ndarray and return a value of the same type as the input.
|
/// Build np/sp functions that take as input `NDArray` only
|
||||||
fn build_np_sp_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
let elem_type = self.unifier.get_fresh_var(Some("R".into()), None);
|
|
||||||
let ndarray_type = make_ndarray_ty(self.unifier, self.primitives, Some(elem_type.ty), None);
|
|
||||||
let ndarray_ty =
|
|
||||||
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("T".into()), None);
|
|
||||||
let var_map = into_var_map([elem_type, ndarray_ty]);
|
|
||||||
|
|
||||||
match prim {
|
match prim {
|
||||||
PrimDef::FunNpTranspose => create_fn_by_codegen(
|
PrimDef::FunNpTranspose => {
|
||||||
|
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.ndarray_num_ty],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
create_fn_by_codegen(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
&var_map,
|
&into_var_map([ndarray_ty]),
|
||||||
prim.name(),
|
prim.name(),
|
||||||
ndarray_ty.ty,
|
ndarray_ty.ty,
|
||||||
&[(ndarray_ty.ty, "x")],
|
&[(ndarray_ty.ty, "x")],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||||
}),
|
}),
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpReshape => {
|
|
||||||
// Return type can have differend ndims
|
|
||||||
let ret_ty =
|
|
||||||
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
|
|
||||||
let var_map = var_map
|
|
||||||
.clone()
|
|
||||||
.into_iter()
|
|
||||||
.chain(once((ret_ty.id, ret_ty.ty)))
|
|
||||||
.chain(once((
|
|
||||||
self.ndarray_factory_fn_shape_arg_tvar.id,
|
|
||||||
self.ndarray_factory_fn_shape_arg_tvar.ty,
|
|
||||||
)))
|
|
||||||
.collect::<IndexMap<_, _>>();
|
|
||||||
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&var_map,
|
|
||||||
prim.name(),
|
|
||||||
ret_ty.ty,
|
|
||||||
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
|
||||||
let x2_val =
|
|
||||||
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
|
||||||
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
|
//
|
||||||
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
|
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_num_ty,
|
||||||
|
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1957,12 +1949,13 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
prim,
|
prim,
|
||||||
&[
|
&[
|
||||||
PrimDef::FunNpDot,
|
PrimDef::FunNpDot,
|
||||||
PrimDef::FunNpLinalgMatmul,
|
|
||||||
PrimDef::FunNpLinalgCholesky,
|
PrimDef::FunNpLinalgCholesky,
|
||||||
PrimDef::FunNpLinalgQr,
|
PrimDef::FunNpLinalgQr,
|
||||||
PrimDef::FunNpLinalgSvd,
|
PrimDef::FunNpLinalgSvd,
|
||||||
PrimDef::FunNpLinalgInv,
|
PrimDef::FunNpLinalgInv,
|
||||||
PrimDef::FunNpLinalgPinv,
|
PrimDef::FunNpLinalgPinv,
|
||||||
|
PrimDef::FunNpLinalgMatrixPower,
|
||||||
|
PrimDef::FunNpLinalgDet,
|
||||||
PrimDef::FunSpLinalgLu,
|
PrimDef::FunSpLinalgLu,
|
||||||
PrimDef::FunSpLinalgSchur,
|
PrimDef::FunSpLinalgSchur,
|
||||||
PrimDef::FunSpLinalgHessenberg,
|
PrimDef::FunSpLinalgHessenberg,
|
||||||
@ -1974,7 +1967,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.unifier,
|
self.unifier,
|
||||||
&self.num_or_ndarray_var_map,
|
&self.num_or_ndarray_var_map,
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.primitives.float,
|
self.num_ty.ty,
|
||||||
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
@ -1982,33 +1975,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_dot(
|
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(x1_ty, x1_val),
|
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
self.ndarray_float_2d,
|
|
||||||
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_linalg_matmul(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(x1_ty, x1_val),
|
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -2086,10 +2053,39 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
_ => {
|
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
|
||||||
println!("{:?}", prim.name());
|
self.unifier,
|
||||||
unreachable!()
|
&VarMap::new(),
|
||||||
}
|
prim.name(),
|
||||||
|
self.ndarray_float_2d,
|
||||||
|
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(x1_ty, x1_val),
|
||||||
|
(x2_ty, x2_val),
|
||||||
|
)?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.float,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,12 +104,13 @@ pub enum PrimDef {
|
|||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
FunNpDot,
|
FunNpDot,
|
||||||
FunNpLinalgMatmul,
|
|
||||||
FunNpLinalgCholesky,
|
FunNpLinalgCholesky,
|
||||||
FunNpLinalgQr,
|
FunNpLinalgQr,
|
||||||
FunNpLinalgSvd,
|
FunNpLinalgSvd,
|
||||||
FunNpLinalgInv,
|
FunNpLinalgInv,
|
||||||
FunNpLinalgPinv,
|
FunNpLinalgPinv,
|
||||||
|
FunNpLinalgMatrixPower,
|
||||||
|
FunNpLinalgDet,
|
||||||
FunSpLinalgLu,
|
FunSpLinalgLu,
|
||||||
FunSpLinalgSchur,
|
FunSpLinalgSchur,
|
||||||
FunSpLinalgHessenberg,
|
FunSpLinalgHessenberg,
|
||||||
@ -289,12 +290,13 @@ impl PrimDef {
|
|||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
|
||||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||||
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
||||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||||
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
||||||
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
||||||
|
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
|
||||||
|
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
|
||||||
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
||||||
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
||||||
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
||||||
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if id == &"np_dot".into() {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
|
||||||
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|
{
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
|
|
||||||
|
ndarray_dtype
|
||||||
|
} else {
|
||||||
|
arg0_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
|
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, arg1],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let arg0_ty = arg0.custom.unwrap();
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
@ -1389,7 +1427,45 @@ impl<'a> Inferencer<'a> {
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
// 2-argument ndarray n-dimensional factory functions
|
||||||
|
if id == &"np_reshape".into() && args.len() == 2 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let shape_expr = args.remove(0);
|
||||||
|
let (ndims, shape) =
|
||||||
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||||
|
|
||||||
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
|
||||||
|
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
|
FuncArg {
|
||||||
|
name: "shape".into(),
|
||||||
|
ty: shape.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, shape],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
// 2-argument ndarray n-dimensional creation functions
|
// 2-argument ndarray n-dimensional creation functions
|
||||||
if id == &"np_full".into() && args.len() == 2 {
|
if id == &"np_full".into() && args.len() == 2 {
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
@ -8,7 +8,6 @@ edition = "2021"
|
|||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
nac3parser = { path = "../nac3parser" }
|
nac3parser = { path = "../nac3parser" }
|
||||||
nac3core = { path = "../nac3core" }
|
nac3core = { path = "../nac3core" }
|
||||||
linalg = { path = "./demo/linalg" }
|
|
||||||
|
|
||||||
[dependencies.clap]
|
[dependencies.clap]
|
||||||
version = "4.5"
|
version = "4.5"
|
||||||
|
@ -5,8 +5,8 @@ import importlib.util
|
|||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy as sp
|
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
import scipy as sp
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from numpy import int32, int64, uint32, uint64
|
from numpy import int32, int64, uint32, uint64
|
||||||
@ -231,12 +231,13 @@ def patch(module):
|
|||||||
|
|
||||||
# Linalg functions
|
# Linalg functions
|
||||||
module.np_dot = np.dot
|
module.np_dot = np.dot
|
||||||
module.np_linalg_matmul = np.matmul
|
|
||||||
module.np_linalg_cholesky = np.linalg.cholesky
|
module.np_linalg_cholesky = np.linalg.cholesky
|
||||||
module.np_linalg_qr = np.linalg.qr
|
module.np_linalg_qr = np.linalg.qr
|
||||||
module.np_linalg_svd = np.linalg.svd
|
module.np_linalg_svd = np.linalg.svd
|
||||||
module.np_linalg_inv = np.linalg.inv
|
module.np_linalg_inv = np.linalg.inv
|
||||||
module.np_linalg_pinv = np.linalg.pinv
|
module.np_linalg_pinv = np.linalg.pinv
|
||||||
|
module.np_linalg_matrix_power = np.linalg.matrix_power
|
||||||
|
module.np_linalg_det = np.linalg.det
|
||||||
|
|
||||||
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
||||||
module.sp_linalg_schur = sp.linalg.schur
|
module.sp_linalg_schur = sp.linalg.schur
|
||||||
|
114
nac3standalone/demo/linalg/Cargo.lock
generated
Normal file
114
nac3standalone/demo/linalg/Cargo.lock
generated
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "approx"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cslice"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libm"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linalg"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"cslice",
|
||||||
|
"nalgebra",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nalgebra"
|
||||||
|
version = "0.32.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
"simba",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.46"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||||
|
dependencies = [
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"libm",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simba"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-traits",
|
||||||
|
"paste",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.17.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
@ -9,3 +9,5 @@ crate-type = ["staticlib"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
||||||
cslice = "0.3.0"
|
cslice = "0.3.0"
|
||||||
|
|
||||||
|
[workspace]
|
||||||
|
6
nac3standalone/demo/linalg/build.sh
Executable file
6
nac3standalone/demo/linalg/build.sh
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# Uses rustup to compile the linalg library for i386 and x86_84 architecture
|
||||||
|
|
||||||
|
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 cargo build -Z unstable-options --target x86_64-unknown-linux-gnu --out-dir liblinalg/x86_64"
|
||||||
|
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -Z unstable-options --target i686-unknown-linux-gnu --out-dir liblinalg/i386"
|
BIN
nac3standalone/demo/linalg/liblinalg/i386/liblinalg.a
Normal file
BIN
nac3standalone/demo/linalg/liblinalg/i386/liblinalg.a
Normal file
Binary file not shown.
BIN
nac3standalone/demo/linalg/liblinalg/x86_64/liblinalg.a
Normal file
BIN
nac3standalone/demo/linalg/liblinalg/x86_64/liblinalg.a
Normal file
Binary file not shown.
@ -34,83 +34,6 @@ impl InputMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
|
|
||||||
let mat1 = mat1.as_mut().unwrap();
|
|
||||||
let mat2 = mat2.as_mut().unwrap();
|
|
||||||
|
|
||||||
if !(mat1.ndims == 1 && mat2.ndims == 1) {
|
|
||||||
let err_msg = format!(
|
|
||||||
"expected 1D Vector Input, but received {}-D and {}-D input",
|
|
||||||
mat1.ndims, mat2.ndims
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim1 = (*mat1).get_dims();
|
|
||||||
let dim2 = (*mat2).get_dims();
|
|
||||||
|
|
||||||
if dim1[0] != dim2[0] {
|
|
||||||
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
|
|
||||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
|
|
||||||
|
|
||||||
matrix1.dot(&matrix2)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_matmul(
|
|
||||||
mat1: *mut InputMatrix,
|
|
||||||
mat2: *mut InputMatrix,
|
|
||||||
out: *mut InputMatrix,
|
|
||||||
) {
|
|
||||||
let mat1 = mat1.as_mut().unwrap();
|
|
||||||
let mat2 = mat2.as_mut().unwrap();
|
|
||||||
let out = out.as_mut().unwrap();
|
|
||||||
|
|
||||||
if !(mat1.ndims == 2 && mat2.ndims == 2) {
|
|
||||||
let err_msg = format!(
|
|
||||||
"expected 2D Vector Input, but received {}-D and {}-D input",
|
|
||||||
mat1.ndims, mat2.ndims
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim1 = (*mat1).get_dims();
|
|
||||||
let dim2 = (*mat2).get_dims();
|
|
||||||
|
|
||||||
if dim1[1] != dim2[0] {
|
|
||||||
let err_msg = format!(
|
|
||||||
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
|
|
||||||
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let outdim = out.get_dims();
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
|
|
||||||
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
|
|
||||||
|
|
||||||
matrix1.mul_to(&matrix2, &mut result);
|
|
||||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
@ -120,7 +43,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,7 +91,7 @@ pub unsafe extern "C" fn np_linalg_qr(
|
|||||||
let out_r = out_r.as_mut().unwrap();
|
let out_r = out_r.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,7 +130,7 @@ pub unsafe extern "C" fn np_linalg_svd(
|
|||||||
let outvh = outvh.as_mut().unwrap();
|
let outvh = outvh.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,7 +161,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
let dim1 = (*mat1).get_dims();
|
let dim1 = (*mat1).get_dims();
|
||||||
@ -277,7 +200,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
|||||||
let out = out.as_mut().unwrap();
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
let dim1 = (*mat1).get_dims();
|
let dim1 = (*mat1).get_dims();
|
||||||
@ -299,6 +222,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_matrix_power(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
mat2: *mut InputMatrix,
|
||||||
|
out: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let mat2 = mat2.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||||
|
let power = power[0];
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let abs_pow = power.abs();
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let mut result = matrix1.pow(abs_pow as u32);
|
||||||
|
|
||||||
|
if power < 0.0 {
|
||||||
|
if !result.is_invertible() {
|
||||||
|
report_error(
|
||||||
|
"LinAlgError",
|
||||||
|
"np_linalg_inv",
|
||||||
|
file!(),
|
||||||
|
line!(),
|
||||||
|
column!(),
|
||||||
|
"no inverse for Singular Matrix",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result = result.try_inverse().unwrap();
|
||||||
|
}
|
||||||
|
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix.is_square() {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
out_slice[0] = matrix.determinant();
|
||||||
|
}
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
@ -313,7 +306,7 @@ pub unsafe extern "C" fn sp_linalg_lu(
|
|||||||
let out_u = out_u.as_mut().unwrap();
|
let out_u = out_u.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -346,7 +339,7 @@ pub unsafe extern "C" fn sp_linalg_schur(
|
|||||||
let out_z = out_z.as_mut().unwrap();
|
let out_z = out_z.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -386,7 +379,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
|
|||||||
let out_q = out_q.as_mut().unwrap();
|
let out_q = out_q.as_mut().unwrap();
|
||||||
|
|
||||||
if mat1.ndims != 2 {
|
if mat1.ndims != 2 {
|
||||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
|
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,14 +42,11 @@ done
|
|||||||
|
|
||||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||||
nac3standalone=../../target/debug/nac3standalone
|
nac3standalone=../../target/debug/nac3standalone
|
||||||
linalg=../../target/debug/deps/liblinalg-?*.a
|
|
||||||
elif [ -e ../../target/release/nac3standalone ]; then
|
elif [ -e ../../target/release/nac3standalone ]; then
|
||||||
nac3standalone=../../target/release/nac3standalone
|
nac3standalone=../../target/release/nac3standalone
|
||||||
linalg=../../target/release/deps/liblinalg-?*.a
|
|
||||||
else
|
else
|
||||||
# used by Nix builds
|
# used by Nix builds
|
||||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||||
linalg=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg-?*.a
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -f ./*.o ./*.bc demo
|
rm -f ./*.o ./*.bc demo
|
||||||
@ -57,21 +54,16 @@ rm -f ./*.o ./*.bc demo
|
|||||||
if [ -z "$i386" ]; then
|
if [ -z "$i386" ]; then
|
||||||
$nac3standalone "${nac3args[@]}"
|
$nac3standalone "${nac3args[@]}"
|
||||||
|
|
||||||
|
cd linalg && cargo build -q && cd ..
|
||||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||||
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
|
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/x86_64/liblinalg.a
|
||||||
else
|
else
|
||||||
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
|
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
|
||||||
|
|
||||||
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
|
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
|
||||||
|
|
||||||
# Compile linalg crate to provide functions compatible with i386 architecture
|
|
||||||
if [ ! -f ../../target/i686-unknown-linux-gnu/release/liblinalg.a ]; then
|
|
||||||
cd linalg && nix-shell -p rustup --command "RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -q --release --target=i686-unknown-linux-gnu" && cd ..
|
|
||||||
fi
|
|
||||||
|
|
||||||
linalg=../../target/i686-unknown-linux-gnu/release/liblinalg.a
|
|
||||||
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
|
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
|
||||||
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
|
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/i386/liblinalg.a
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -z "$outfile" ]; then
|
if [ -z "$outfile" ]; then
|
||||||
|
@ -68,12 +68,6 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
|||||||
for c in range(len(n[r])):
|
for c in range(len(n[r])):
|
||||||
output_float64(n[r][c])
|
output_float64(n[r][c])
|
||||||
|
|
||||||
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
|
|
||||||
for r in range(len(n)):
|
|
||||||
for c in range(len(n[r])):
|
|
||||||
for z in range(len(n[r][c])):
|
|
||||||
output_float64(n[r][c][z])
|
|
||||||
|
|
||||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1436,43 +1430,49 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
def test_ndarray_transpose():
|
def test_ndarray_transpose():
|
||||||
# x: ndarray[float, 3] = np_array([[[1., 2.], [3., 4.], [5., 6.]]])
|
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
|
||||||
x: ndarray[float, 1] = np_array([1., 2., 3.])
|
|
||||||
y = np_transpose(x)
|
y = np_transpose(x)
|
||||||
|
z = np_transpose(y)
|
||||||
|
|
||||||
output_ndarray_float_1(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_1(y)
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
def test_ndarray_reshape():
|
def test_ndarray_reshape():
|
||||||
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
||||||
x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1))
|
x = np_reshape(w, (1, 2, 1, -1))
|
||||||
y: ndarray[float, 2] = np_reshape(x, [2, -1])
|
y = np_reshape(x, [2, -1])
|
||||||
z: ndarray[float, 1] = np_reshape(w, 10)
|
z = np_reshape(y, 10)
|
||||||
|
|
||||||
|
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
|
||||||
|
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
|
||||||
|
|
||||||
output_ndarray_float_1(w)
|
output_ndarray_float_1(w)
|
||||||
output_ndarray_float_2(y)
|
output_ndarray_float_2(y)
|
||||||
output_ndarray_float_1(z)
|
output_ndarray_float_1(z)
|
||||||
|
|
||||||
def test_ndarray_dot():
|
def test_ndarray_dot():
|
||||||
x: ndarray[float, 1] = np_array([5.0, 1.0])
|
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||||
y: ndarray[float, 1] = np_array([5.0, 1.0])
|
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||||
z = np_dot(x, y)
|
z1 = np_dot(x1, y1)
|
||||||
|
|
||||||
output_ndarray_float_1(x)
|
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
|
||||||
output_ndarray_float_1(y)
|
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
|
||||||
output_float64(z)
|
z2 = np_dot(x2, y2)
|
||||||
|
|
||||||
def test_ndarray_linalg_matmul():
|
x3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
y3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
z3 = np_dot(x3, y3)
|
||||||
z = np_linalg_matmul(x, y)
|
|
||||||
|
|
||||||
m = np_argmax(z)
|
z4 = np_dot(2, 3)
|
||||||
|
z5 = np_dot(2., 3.)
|
||||||
|
z6 = np_dot(True, False)
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
output_float64(z1)
|
||||||
output_ndarray_float_2(y)
|
output_int32(z2)
|
||||||
output_ndarray_float_2(z)
|
output_bool(z3)
|
||||||
output_int64(m)
|
output_int32(z4)
|
||||||
|
output_float64(z5)
|
||||||
|
output_bool(z6)
|
||||||
|
|
||||||
def test_ndarray_cholesky():
|
def test_ndarray_cholesky():
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||||
@ -1489,7 +1489,7 @@ def test_ndarray_qr():
|
|||||||
|
|
||||||
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(y, z)
|
a = y @ z
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
def test_ndarray_linalg_inv():
|
def test_ndarray_linalg_inv():
|
||||||
@ -1506,6 +1506,20 @@ def test_ndarray_pinv():
|
|||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_2(y)
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_matrix_power():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_matrix_power(x, -9)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_det():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_det(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
def test_ndarray_schur():
|
def test_ndarray_schur():
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
t, z = sp_linalg_schur(x)
|
t, z = sp_linalg_schur(x)
|
||||||
@ -1514,7 +1528,7 @@ def test_ndarray_schur():
|
|||||||
|
|
||||||
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
|
a = (z @ t) @ np_linalg_inv(z)
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
def test_ndarray_hessenberg():
|
def test_ndarray_hessenberg():
|
||||||
@ -1525,7 +1539,7 @@ def test_ndarray_hessenberg():
|
|||||||
|
|
||||||
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q))
|
a = (q @ h) @ np_linalg_inv(q)
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
|
|
||||||
@ -1546,7 +1560,7 @@ def test_ndarray_svd():
|
|||||||
|
|
||||||
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
# Reverting the decomposition to compare the initial arrays
|
# Reverting the decomposition to compare the initial arrays
|
||||||
a = np_linalg_matmul(x, z)
|
a = x @ z
|
||||||
output_ndarray_float_2(a)
|
output_ndarray_float_2(a)
|
||||||
output_ndarray_float_1(y)
|
output_ndarray_float_1(y)
|
||||||
|
|
||||||
@ -1733,12 +1747,13 @@ def run() -> int32:
|
|||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_dot()
|
||||||
test_ndarray_linalg_matmul()
|
|
||||||
test_ndarray_cholesky()
|
test_ndarray_cholesky()
|
||||||
test_ndarray_qr()
|
test_ndarray_qr()
|
||||||
test_ndarray_svd()
|
test_ndarray_svd()
|
||||||
test_ndarray_linalg_inv()
|
test_ndarray_linalg_inv()
|
||||||
test_ndarray_pinv()
|
test_ndarray_pinv()
|
||||||
|
test_ndarray_matrix_power()
|
||||||
|
test_ndarray_det()
|
||||||
test_ndarray_lu()
|
test_ndarray_lu()
|
||||||
test_ndarray_schur()
|
test_ndarray_schur()
|
||||||
test_ndarray_hessenberg()
|
test_ndarray_hessenberg()
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../target/debug/libnac3artiq.so
|
|
Loading…
Reference in New Issue
Block a user