forked from M-Labs/nac3
Compare commits
9 Commits
2237137f1a
...
c71a567a51
Author | SHA1 | Date | |
---|---|---|---|
|
c71a567a51 | ||
260a2fbb63 | |||
9e7cf4fcac | |||
688e85d13c | |||
c6bac576d4 | |||
9cdfaf96fd | |||
794138156d | |||
3980b8d353 | |||
5d74c1848d |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,3 @@
|
||||
__pycache__
|
||||
/target
|
||||
/nac3standalone/demo/linalg/target
|
||||
nix/windows/msys2
|
||||
|
106
Cargo.lock
generated
106
Cargo.lock
generated
@ -73,6 +73,15 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ascii-canvas"
|
||||
version = "3.0.0"
|
||||
@ -247,6 +256,12 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "cslice"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
||||
|
||||
[[package]]
|
||||
name = "dirs-next"
|
||||
version = "2.0.0"
|
||||
@ -521,6 +536,12 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.3"
|
||||
@ -531,6 +552,14 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linalg"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cslice",
|
||||
"nalgebra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
@ -659,17 +688,70 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"inkwell",
|
||||
"linalg",
|
||||
"nac3core",
|
||||
"nac3parser",
|
||||
"parking_lot",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.32.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "new_debug_unreachable"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
@ -699,6 +781,12 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "petgraph"
|
||||
version = "0.6.5"
|
||||
@ -1070,6 +1158,18 @@ dependencies = [
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.6.0"
|
||||
@ -1230,6 +1330,12 @@ dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "unic-char-property"
|
||||
version = "0.9.0"
|
||||
|
@ -161,7 +161,9 @@
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||
name = "nac3-dev-shell-msys2";
|
||||
|
@ -3,9 +3,7 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{
|
||||
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
};
|
||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
@ -1867,6 +1865,83 @@ fn build_output_struct<'ctx>(
|
||||
out_ptr
|
||||
}
|
||||
|
||||
/// Invokes the `np_dot` linalg function
|
||||
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_dot";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
|
||||
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
|
||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
||||
|
||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
||||
else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_matmul` linalg function
|
||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matmul";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
||||
|
||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
||||
else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n2.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_cholesky` linalg function
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
@ -2149,104 +2224,6 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let n2_array = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
n2_array.data().set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
n2.as_basic_value_enum(),
|
||||
);
|
||||
};
|
||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_det` linalg function
|
||||
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let out = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||
let res =
|
||||
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||
Ok(res)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_schur` linalg function
|
||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
|
@ -179,13 +179,42 @@ macro_rules! generate_linalg_extern_fn {
|
||||
};
|
||||
}
|
||||
|
||||
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
|
||||
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
||||
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
||||
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
||||
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
||||
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
||||
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
||||
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
||||
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
||||
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
||||
|
||||
/// Invokes the linalg `np_dot` function.
|
||||
pub fn call_np_dot<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
mat1: BasicValueEnum<'ctx>,
|
||||
mat2: BasicValueEnum<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "np_dot";
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type =
|
||||
ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
@ -26,15 +26,12 @@ use crate::{
|
||||
typedef::{FunSignature, Type, TypeEnum},
|
||||
},
|
||||
};
|
||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||
use inkwell::{
|
||||
types::BasicType,
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use inkwell::{
|
||||
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
||||
values::BasicValue,
|
||||
};
|
||||
use nac3parser::ast::{Operator, StrRef};
|
||||
|
||||
/// Creates an uninitialized `NDArray` instance.
|
||||
@ -2030,7 +2027,6 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.transpose`.
|
||||
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -2045,7 +2041,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
|
||||
// Dimensions are reversed in the transposed array
|
||||
let out = create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2071,16 +2066,30 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
(n_sz, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
// Calculate transposed idx
|
||||
// 2, 3 => idx = row * num_col + col = 3 + 2 = 5
|
||||
// 2, 3, 4 => idx = row * (num_col*num_z) + col * (num_z) + z => 12 + 8 + 3 | 18, [4, 2] 15, ,1,2
|
||||
// num_z (col + num_col * row) + z
|
||||
// 4D => 2, 3, 4, 5 = idx = num_w (num_z (col + num_col*row) + z) + w = 119
|
||||
// [1, 1, 2]?
|
||||
// 4 (1 + 3*1) + 2 = 6
|
||||
// z = 2, col = 1, row = 0
|
||||
// 0,1,
|
||||
// 2, 3 => idx = row * num_col + col | 2, 3, 4 => idx = row * (num_col * num_z) + col * (num_z) + num_z
|
||||
// ND => idx = 1 * (dim0 + dim1 + ... dimn) + dim[-1] * (dim0 + dim1 + ... + dimn-1) + ... + dim[1] * dim0
|
||||
// 6 + 12 + 6 = 24 num_z * (row*num_col + col + 1) 4*6=24
|
||||
// 2, 3, 4, 5 at idx 1 should go to
|
||||
// 5, 4, 3, 2
|
||||
|
||||
// 18 => [2, 4] dim = 4
|
||||
// 0 * 4 + 2 = 2
|
||||
// 4 => [1, 1] dim = 3
|
||||
// 2 * 3 + 1
|
||||
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
||||
ctx.builder.build_store(rem_idx, idx).unwrap();
|
||||
|
||||
// Incrementally calculate the new index in the transposed array
|
||||
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
||||
// The formula used for indexing is:
|
||||
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2136,15 +2145,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
||||
///
|
||||
/// * `x1` - `NDArray` to reshape.
|
||||
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||
/// Just like numpy, the `shape` argument can be:
|
||||
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -2162,19 +2162,44 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
|
||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
// Check for -1 in the shapec
|
||||
let ndim_ty = match shape {
|
||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||
{
|
||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||
shape_list
|
||||
.data()
|
||||
.get(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
.get_type()
|
||||
}
|
||||
BasicValueEnum::StructValue(shape_tuple) => ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, 0, "")
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
.get_type(),
|
||||
BasicValueEnum::IntValue(shape_int) => shape_int.get_type(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let n_sz = ctx
|
||||
.builder
|
||||
.build_cast(inkwell::values::InstructionOpcode::Trunc, n_sz, ndim_ty, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
|
||||
let acc = generator.gen_var_alloc(ctx, ndim_ty.into(), None)?;
|
||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
||||
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
|
||||
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
let out = match shape {
|
||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||
{
|
||||
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
||||
|
||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||
// Check for -1 in dimensions
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2184,7 +2209,6 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|generator, ctx, _, idx| {
|
||||
let ele =
|
||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
||||
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
@ -2195,7 +2219,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
ele,
|
||||
llvm_usize.const_zero(),
|
||||
ndim_ty.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
@ -2229,7 +2253,6 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)?;
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||
// Generate the output shape by filling -1 with `rem`
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2239,8 +2262,6 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|generator, ctx, shape_list, idx| {
|
||||
let dim =
|
||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
Ok(gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2250,7 +2271,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
ndim_ty.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
@ -2264,17 +2285,16 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)
|
||||
}
|
||||
BasicValueEnum::StructValue(shape_tuple) => {
|
||||
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||
|
||||
let ndims = shape_tuple.get_type().count_fields();
|
||||
// Check for -1 in dims
|
||||
let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
|
||||
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
|
||||
|
||||
for dim_i in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, dim_i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
@ -2282,12 +2302,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
|
||||
.unwrap())
|
||||
},
|
||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||
@ -2313,14 +2328,12 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||
let mut shape = Vec::with_capacity(ndims as usize);
|
||||
|
||||
// Reconstruct shape filling negatives with rem
|
||||
for dim_i in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, dim_i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
let dim = gen_if_else_expr_callback(
|
||||
generator,
|
||||
@ -2328,12 +2341,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(rem)),
|
||||
@ -2346,7 +2354,6 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||
}
|
||||
BasicValueEnum::IntValue(shape_int) => {
|
||||
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||
let shape_int = gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
@ -2356,15 +2363,13 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
shape_int,
|
||||
llvm_usize.const_zero(),
|
||||
ndim_ty.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(n_sz)),
|
||||
|_, ctx| {
|
||||
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
||||
},
|
||||
|_, _| Ok(Some(shape_int)),
|
||||
)?
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
@ -2374,7 +2379,6 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
// Only allow one dimension to be negative
|
||||
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
@ -2387,13 +2391,17 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// The new shape must be compatible with the old shape
|
||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
||||
let out_sz = ctx
|
||||
.builder
|
||||
.build_cast(inkwell::values::InstructionOpcode::Trunc, out_sz, ndim_ty, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"cannot reshape array of size {0} into provided shape of size {1}",
|
||||
"cannot reshape array of size {} into provided shape of size {}",
|
||||
[Some(n_sz), Some(out_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
@ -2420,102 +2428,3 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.dot`.
|
||||
/// Calculate inner product of two vectors or literals
|
||||
/// For matrix multiplication use `np_matmul`
|
||||
///
|
||||
/// The input `NDArray` are flattened and treated as 1D
|
||||
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
||||
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "ndarray_dot";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (_, x2) = x2;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match (x1, x2) {
|
||||
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||
|
||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"shapes ({0}), ({1}) not aligned",
|
||||
[Some(n1_sz), Some(n2_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let identity =
|
||||
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
||||
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(n1_sz, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
|
||||
let product = match elem1 {
|
||||
BasicValueEnum::IntValue(e1) => ctx
|
||||
.builder
|
||||
.build_int_mul(e1, elem2.into_int_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
BasicValueEnum::FloatValue(e1) => ctx
|
||||
.builder
|
||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||
let acc_val = match acc_val {
|
||||
BasicValueEnum::IntValue(e1) => ctx
|
||||
.builder
|
||||
.build_int_add(e1, product.into_int_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
BasicValueEnum::FloatValue(e1) => ctx
|
||||
.builder
|
||||
.build_float_add(e1, product.into_float_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||
Ok(acc_val)
|
||||
}
|
||||
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||
}
|
||||
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
||||
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||
}
|
||||
_ => unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -558,17 +558,16 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||
|
||||
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||
self.build_np_sp_ndarray_function(prim)
|
||||
self.build_np_sp_ndarray_1ary_function(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunNpDot
|
||||
| PrimDef::FunNpLinalgMatmul
|
||||
| PrimDef::FunNpLinalgCholesky
|
||||
| PrimDef::FunNpLinalgQr
|
||||
| PrimDef::FunNpLinalgSvd
|
||||
| PrimDef::FunNpLinalgInv
|
||||
| PrimDef::FunNpLinalgPinv
|
||||
| PrimDef::FunNpLinalgMatrixPower
|
||||
| PrimDef::FunNpLinalgDet
|
||||
| PrimDef::FunSpLinalgLu
|
||||
| PrimDef::FunSpLinalgSchur
|
||||
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
||||
@ -1890,53 +1889,62 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Build np/sp functions that take as input `NDArray` only
|
||||
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
/// Build 1-ary numpy/scipy functions that take in an ndarray and return a value of the same type as the input.
|
||||
fn build_np_sp_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||
|
||||
let elem_type = self.unifier.get_fresh_var(Some("R".into()), None);
|
||||
let ndarray_type = make_ndarray_ty(self.unifier, self.primitives, Some(elem_type.ty), None);
|
||||
let ndarray_ty =
|
||||
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("T".into()), None);
|
||||
let var_map = into_var_map([elem_type, ndarray_ty]);
|
||||
|
||||
match prim {
|
||||
PrimDef::FunNpTranspose => {
|
||||
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[self.ndarray_num_ty],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
create_fn_by_codegen(
|
||||
PrimDef::FunNpTranspose => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&into_var_map([ndarray_ty]),
|
||||
&var_map,
|
||||
prim.name(),
|
||||
ndarray_ty.ty,
|
||||
&[(ndarray_ty.ty, "x")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||
}),
|
||||
),
|
||||
|
||||
PrimDef::FunNpReshape => {
|
||||
// Return type can have differend ndims
|
||||
let ret_ty =
|
||||
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
|
||||
let var_map = var_map
|
||||
.clone()
|
||||
.into_iter()
|
||||
.chain(once((ret_ty.id, ret_ty.ty)))
|
||||
.chain(once((
|
||||
self.ndarray_factory_fn_shape_arg_tvar.id,
|
||||
self.ndarray_factory_fn_shape_arg_tvar.ty,
|
||||
)))
|
||||
.collect::<IndexMap<_, _>>();
|
||||
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&var_map,
|
||||
prim.name(),
|
||||
ret_ty.ty,
|
||||
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val =
|
||||
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||
// the `param_ty` for `create_fn_by_codegen`.
|
||||
//
|
||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_num_ty,
|
||||
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
}),
|
||||
),
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -1949,13 +1957,12 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
prim,
|
||||
&[
|
||||
PrimDef::FunNpDot,
|
||||
PrimDef::FunNpLinalgMatmul,
|
||||
PrimDef::FunNpLinalgCholesky,
|
||||
PrimDef::FunNpLinalgQr,
|
||||
PrimDef::FunNpLinalgSvd,
|
||||
PrimDef::FunNpLinalgInv,
|
||||
PrimDef::FunNpLinalgPinv,
|
||||
PrimDef::FunNpLinalgMatrixPower,
|
||||
PrimDef::FunNpLinalgDet,
|
||||
PrimDef::FunSpLinalgLu,
|
||||
PrimDef::FunSpLinalgSchur,
|
||||
PrimDef::FunSpLinalgHessenberg,
|
||||
@ -1967,7 +1974,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.unifier,
|
||||
&self.num_or_ndarray_var_map,
|
||||
prim.name(),
|
||||
self.num_ty.ty,
|
||||
self.primitives.float,
|
||||
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
@ -1975,7 +1982,33 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
Ok(Some(builtin_fns::call_np_dot(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
}),
|
||||
),
|
||||
|
||||
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_matmul(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
}),
|
||||
),
|
||||
|
||||
@ -2053,39 +2086,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}),
|
||||
)
|
||||
}
|
||||
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
}),
|
||||
),
|
||||
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.primitives.float,
|
||||
&[(self.ndarray_float_2d, "x1")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
|
||||
}),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
_ => {
|
||||
println!("{:?}", prim.name());
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,13 +104,12 @@ pub enum PrimDef {
|
||||
|
||||
// Linalg functions
|
||||
FunNpDot,
|
||||
FunNpLinalgMatmul,
|
||||
FunNpLinalgCholesky,
|
||||
FunNpLinalgQr,
|
||||
FunNpLinalgSvd,
|
||||
FunNpLinalgInv,
|
||||
FunNpLinalgPinv,
|
||||
FunNpLinalgMatrixPower,
|
||||
FunNpLinalgDet,
|
||||
FunSpLinalgLu,
|
||||
FunSpLinalgSchur,
|
||||
FunSpLinalgHessenberg,
|
||||
@ -290,13 +289,12 @@ impl PrimDef {
|
||||
|
||||
// Linalg functions
|
||||
PrimDef::FunNpDot => fun("np_dot", None),
|
||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
||||
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
||||
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
|
||||
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
|
||||
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
||||
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
||||
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
||||
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
||||
]
|
||||
|
@ -1130,44 +1130,6 @@ impl<'a> Inferencer<'a> {
|
||||
}));
|
||||
}
|
||||
|
||||
if id == &"np_dot".into() {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg1 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
|
||||
ndarray_dtype
|
||||
} else {
|
||||
arg0_ty
|
||||
};
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||
}),
|
||||
args: vec![arg0, arg1],
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
@ -1427,45 +1389,7 @@ impl<'a> Inferencer<'a> {
|
||||
},
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional factory functions
|
||||
if id == &"np_reshape".into() && args.len() == 2 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let shape_expr = args.remove(0);
|
||||
let (ndims, shape) =
|
||||
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||
|
||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
|
||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||
FuncArg {
|
||||
name: "shape".into(),
|
||||
ty: shape.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||
}),
|
||||
args: vec![arg0, shape],
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional creation functions
|
||||
if id == &"np_full".into() && args.len() == 2 {
|
||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||
|
@ -8,6 +8,7 @@ edition = "2021"
|
||||
parking_lot = "0.12"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
nac3core = { path = "../nac3core" }
|
||||
linalg = { path = "./demo/linalg" }
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
|
@ -5,8 +5,8 @@ import importlib.util
|
||||
import importlib.machinery
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import scipy as sp
|
||||
import numpy.typing as npt
|
||||
import pathlib
|
||||
|
||||
from numpy import int32, int64, uint32, uint64
|
||||
@ -231,13 +231,12 @@ def patch(module):
|
||||
|
||||
# Linalg functions
|
||||
module.np_dot = np.dot
|
||||
module.np_linalg_matmul = np.matmul
|
||||
module.np_linalg_cholesky = np.linalg.cholesky
|
||||
module.np_linalg_qr = np.linalg.qr
|
||||
module.np_linalg_svd = np.linalg.svd
|
||||
module.np_linalg_inv = np.linalg.inv
|
||||
module.np_linalg_pinv = np.linalg.pinv
|
||||
module.np_linalg_matrix_power = np.linalg.matrix_power
|
||||
module.np_linalg_det = np.linalg.det
|
||||
|
||||
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
||||
module.sp_linalg_schur = sp.linalg.schur
|
||||
|
114
nac3standalone/demo/linalg/Cargo.lock
generated
114
nac3standalone/demo/linalg/Cargo.lock
generated
@ -1,114 +0,0 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||
|
||||
[[package]]
|
||||
name = "cslice"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "linalg"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cslice",
|
||||
"nalgebra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.32.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
@ -9,5 +9,3 @@ crate-type = ["staticlib"]
|
||||
[dependencies]
|
||||
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
||||
cslice = "0.3.0"
|
||||
|
||||
[workspace]
|
||||
|
@ -1,6 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Uses rustup to compile the linalg library for i386 and x86_84 architecture
|
||||
|
||||
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 cargo build -Z unstable-options --target x86_64-unknown-linux-gnu --out-dir liblinalg/x86_64"
|
||||
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -Z unstable-options --target i686-unknown-linux-gnu --out-dir liblinalg/i386"
|
Binary file not shown.
Binary file not shown.
@ -34,6 +34,83 @@ impl InputMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
|
||||
if !(mat1.ndims == 1 && mat2.ndims == 1) {
|
||||
let err_msg = format!(
|
||||
"expected 1D Vector Input, but received {}-D and {}-D input",
|
||||
mat1.ndims, mat2.ndims
|
||||
);
|
||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let dim2 = (*mat2).get_dims();
|
||||
|
||||
if dim1[0] != dim2[0] {
|
||||
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
|
||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
|
||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
|
||||
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
|
||||
|
||||
matrix1.dot(&matrix2)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_matmul(
|
||||
mat1: *mut InputMatrix,
|
||||
mat2: *mut InputMatrix,
|
||||
out: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if !(mat1.ndims == 2 && mat2.ndims == 2) {
|
||||
let err_msg = format!(
|
||||
"expected 2D Vector Input, but received {}-D and {}-D input",
|
||||
mat1.ndims, mat2.ndims
|
||||
);
|
||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let dim2 = (*mat2).get_dims();
|
||||
|
||||
if dim1[1] != dim2[0] {
|
||||
let err_msg = format!(
|
||||
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
|
||||
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
|
||||
);
|
||||
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
|
||||
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
|
||||
|
||||
matrix1.mul_to(&matrix2, &mut result);
|
||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
@ -43,7 +120,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
@ -91,7 +168,7 @@ pub unsafe extern "C" fn np_linalg_qr(
|
||||
let out_r = out_r.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
@ -130,7 +207,7 @@ pub unsafe extern "C" fn np_linalg_svd(
|
||||
let outvh = outvh.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
@ -161,7 +238,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
@ -200,7 +277,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
@ -222,76 +299,6 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_matrix_power(
|
||||
mat1: *mut InputMatrix,
|
||||
mat2: *mut InputMatrix,
|
||||
out: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||
let power = power[0];
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let abs_pow = power.abs();
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let mut result = matrix1.pow(abs_pow as u32);
|
||||
|
||||
if power < 0.0 {
|
||||
if !result.is_invertible() {
|
||||
report_error(
|
||||
"LinAlgError",
|
||||
"np_linalg_inv",
|
||||
file!(),
|
||||
line!(),
|
||||
column!(),
|
||||
"no inverse for Singular Matrix",
|
||||
);
|
||||
}
|
||||
result = result.try_inverse().unwrap();
|
||||
}
|
||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
if !matrix.is_square() {
|
||||
let err_msg =
|
||||
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
out_slice[0] = matrix.determinant();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
@ -306,7 +313,7 @@ pub unsafe extern "C" fn sp_linalg_lu(
|
||||
let out_u = out_u.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
@ -339,7 +346,7 @@ pub unsafe extern "C" fn sp_linalg_schur(
|
||||
let out_z = out_z.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
@ -379,7 +386,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
|
||||
let out_q = out_q.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
|
||||
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
|
||||
}
|
||||
|
||||
|
@ -42,11 +42,14 @@ done
|
||||
|
||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||
nac3standalone=../../target/debug/nac3standalone
|
||||
linalg=../../target/debug/deps/liblinalg-?*.a
|
||||
elif [ -e ../../target/release/nac3standalone ]; then
|
||||
nac3standalone=../../target/release/nac3standalone
|
||||
linalg=../../target/release/deps/liblinalg-?*.a
|
||||
else
|
||||
# used by Nix builds
|
||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||
linalg=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg-?*.a
|
||||
fi
|
||||
|
||||
rm -f ./*.o ./*.bc demo
|
||||
@ -54,16 +57,21 @@ rm -f ./*.o ./*.bc demo
|
||||
if [ -z "$i386" ]; then
|
||||
$nac3standalone "${nac3args[@]}"
|
||||
|
||||
cd linalg && cargo build -q && cd ..
|
||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/x86_64/liblinalg.a
|
||||
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
|
||||
else
|
||||
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
|
||||
|
||||
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
|
||||
|
||||
# Compile linalg crate to provide functions compatible with i386 architecture
|
||||
if [ ! -f ../../target/i686-unknown-linux-gnu/release/liblinalg.a ]; then
|
||||
cd linalg && nix-shell -p rustup --command "RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -q --release --target=i686-unknown-linux-gnu" && cd ..
|
||||
fi
|
||||
|
||||
linalg=../../target/i686-unknown-linux-gnu/release/liblinalg.a
|
||||
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
|
||||
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/i386/liblinalg.a
|
||||
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
|
||||
fi
|
||||
|
||||
if [ -z "$outfile" ]; then
|
||||
|
@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||
for c in range(len(n[r])):
|
||||
output_float64(n[r][c])
|
||||
|
||||
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
for z in range(len(n[r][c])):
|
||||
output_float64(n[r][c][z])
|
||||
|
||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||
pass
|
||||
|
||||
@ -1430,49 +1436,43 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
def test_ndarray_transpose():
|
||||
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
|
||||
# x: ndarray[float, 3] = np_array([[[1., 2.], [3., 4.], [5., 6.]]])
|
||||
x: ndarray[float, 1] = np_array([1., 2., 3.])
|
||||
y = np_transpose(x)
|
||||
z = np_transpose(y)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
output_ndarray_float_1(x)
|
||||
output_ndarray_float_1(y)
|
||||
|
||||
def test_ndarray_reshape():
|
||||
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
||||
x = np_reshape(w, (1, 2, 1, -1))
|
||||
y = np_reshape(x, [2, -1])
|
||||
z = np_reshape(y, 10)
|
||||
|
||||
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
|
||||
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
|
||||
x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1))
|
||||
y: ndarray[float, 2] = np_reshape(x, [2, -1])
|
||||
z: ndarray[float, 1] = np_reshape(w, 10)
|
||||
|
||||
output_ndarray_float_1(w)
|
||||
output_ndarray_float_2(y)
|
||||
output_ndarray_float_1(z)
|
||||
|
||||
def test_ndarray_dot():
|
||||
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||
z1 = np_dot(x1, y1)
|
||||
x: ndarray[float, 1] = np_array([5.0, 1.0])
|
||||
y: ndarray[float, 1] = np_array([5.0, 1.0])
|
||||
z = np_dot(x, y)
|
||||
|
||||
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
|
||||
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
|
||||
z2 = np_dot(x2, y2)
|
||||
output_ndarray_float_1(x)
|
||||
output_ndarray_float_1(y)
|
||||
output_float64(z)
|
||||
|
||||
x3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||
y3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||
z3 = np_dot(x3, y3)
|
||||
def test_ndarray_linalg_matmul():
|
||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
z = np_linalg_matmul(x, y)
|
||||
|
||||
z4 = np_dot(2, 3)
|
||||
z5 = np_dot(2., 3.)
|
||||
z6 = np_dot(True, False)
|
||||
m = np_argmax(z)
|
||||
|
||||
output_float64(z1)
|
||||
output_int32(z2)
|
||||
output_bool(z3)
|
||||
output_int32(z4)
|
||||
output_float64(z5)
|
||||
output_bool(z6)
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
output_ndarray_float_2(z)
|
||||
output_int64(m)
|
||||
|
||||
def test_ndarray_cholesky():
|
||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
@ -1489,7 +1489,7 @@ def test_ndarray_qr():
|
||||
|
||||
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
||||
# Reverting the decomposition to compare the initial arrays
|
||||
a = y @ z
|
||||
a = np_linalg_matmul(y, z)
|
||||
output_ndarray_float_2(a)
|
||||
|
||||
def test_ndarray_linalg_inv():
|
||||
@ -1506,20 +1506,6 @@ def test_ndarray_pinv():
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_matrix_power():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
y = np_linalg_matrix_power(x, -9)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_det():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
y = np_linalg_det(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
|
||||
def test_ndarray_schur():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
t, z = sp_linalg_schur(x)
|
||||
@ -1528,7 +1514,7 @@ def test_ndarray_schur():
|
||||
|
||||
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
||||
# Reverting the decomposition to compare the initial arrays
|
||||
a = (z @ t) @ np_linalg_inv(z)
|
||||
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
|
||||
output_ndarray_float_2(a)
|
||||
|
||||
def test_ndarray_hessenberg():
|
||||
@ -1539,7 +1525,7 @@ def test_ndarray_hessenberg():
|
||||
|
||||
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
||||
# Reverting the decomposition to compare the initial arrays
|
||||
a = (q @ h) @ np_linalg_inv(q)
|
||||
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q))
|
||||
output_ndarray_float_2(a)
|
||||
|
||||
|
||||
@ -1560,7 +1546,7 @@ def test_ndarray_svd():
|
||||
|
||||
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
||||
# Reverting the decomposition to compare the initial arrays
|
||||
a = x @ z
|
||||
a = np_linalg_matmul(x, z)
|
||||
output_ndarray_float_2(a)
|
||||
output_ndarray_float_1(y)
|
||||
|
||||
@ -1747,13 +1733,12 @@ def run() -> int32:
|
||||
test_ndarray_reshape()
|
||||
|
||||
test_ndarray_dot()
|
||||
test_ndarray_linalg_matmul()
|
||||
test_ndarray_cholesky()
|
||||
test_ndarray_qr()
|
||||
test_ndarray_svd()
|
||||
test_ndarray_linalg_inv()
|
||||
test_ndarray_pinv()
|
||||
test_ndarray_matrix_power()
|
||||
test_ndarray_det()
|
||||
test_ndarray_lu()
|
||||
test_ndarray_schur()
|
||||
test_ndarray_hessenberg()
|
||||
|
1
pyo3_output/nac3artiq.so
Symbolic link
1
pyo3_output/nac3artiq.so
Symbolic link
@ -0,0 +1 @@
|
||||
../target/debug/libnac3artiq.so
|
Loading…
Reference in New Issue
Block a user