Compare commits

..

11 Commits

25 changed files with 677 additions and 504 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

106
Cargo.lock generated
View File

@ -73,15 +73,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "3.0.0" version = "3.0.0"
@ -256,12 +247,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -536,12 +521,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
@ -552,14 +531,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -688,70 +659,17 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
] ]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.19.0"
@ -781,12 +699,6 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.6.5" version = "0.6.5"
@ -1158,18 +1070,6 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.6.0" version = "2.6.0"
@ -1330,12 +1230,6 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
version = "0.9.0" version = "0.9.0"

View File

@ -161,9 +161,7 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -3,7 +3,9 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -1865,83 +1867,6 @@ fn build_output_struct<'ctx>(
out_ptr out_ptr
} }
/// Invokes the `np_dot` linalg function
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_matmul` linalg function
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n2.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_cholesky` linalg function /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
@ -2224,6 +2149,104 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_matrix_power` linalg function
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_det` linalg function
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_schur` linalg function /// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,

View File

@ -179,42 +179,13 @@ macro_rules! generate_linalg_extern_fn {
}; };
} }
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2); generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2); generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3); generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3); generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3); generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
/// Invokes the linalg `np_dot` function.
pub fn call_np_dot<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
mat1: BasicValueEnum<'ctx>,
mat2: BasicValueEnum<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "np_dot";
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type =
ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -26,12 +26,15 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
@ -2027,6 +2030,7 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2041,6 +2045,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape( let out = create_ndarray_dyn_shape(
generator, generator,
ctx, ctx,
@ -2066,30 +2071,16 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
(n_sz, false), (n_sz, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
// Calculate transposed idx
// 2, 3 => idx = row * num_col + col = 3 + 2 = 5
// 2, 3, 4 => idx = row * (num_col*num_z) + col * (num_z) + z => 12 + 8 + 3 | 18, [4, 2] 15, ,1,2
// num_z (col + num_col * row) + z
// 4D => 2, 3, 4, 5 = idx = num_w (num_z (col + num_col*row) + z) + w = 119
// [1, 1, 2]?
// 4 (1 + 3*1) + 2 = 6
// z = 2, col = 1, row = 0
// 0,1,
// 2, 3 => idx = row * num_col + col | 2, 3, 4 => idx = row * (num_col * num_z) + col * (num_z) + num_z
// ND => idx = 1 * (dim0 + dim1 + ... dimn) + dim[-1] * (dim0 + dim1 + ... + dimn-1) + ... + dim[1] * dim0
// 6 + 12 + 6 = 24 num_z * (row*num_col + col + 1) 4*6=24
// 2, 3, 4, 5 at idx 1 should go to
// 5, 4, 3, 2
// 18 => [2, 4] dim = 4
// 0 * 4 + 2 = 2
// 4 => [1, 1] dim = 3
// 2 * 3 + 1
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap(); ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
@ -2145,6 +2136,15 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2162,44 +2162,19 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Check for -1 in the shapec let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let ndim_ty = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
shape_list
.data()
.get(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
.get_type()
}
BasicValueEnum::StructValue(shape_tuple) => ctx
.builder
.build_extract_value(shape_tuple, 0, "")
.unwrap()
.into_int_value()
.get_type(),
BasicValueEnum::IntValue(shape_int) => shape_int.get_type(),
_ => unreachable!(),
};
let n_sz = ctx
.builder
.build_cast(inkwell::values::InstructionOpcode::Trunc, n_sz, ndim_ty, "")
.unwrap()
.into_int_value();
let acc = generator.gen_var_alloc(ctx, ndim_ty.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap(); ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape { let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr) BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{ {
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
ctx, ctx,
@ -2209,6 +2184,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
let ele = let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value(); shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback( gen_if_else_expr_callback(
generator, generator,
@ -2219,7 +2195,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
ele, ele,
ndim_ty.const_zero(), llvm_usize.const_zero(),
"", "",
) )
.unwrap()) .unwrap())
@ -2253,6 +2229,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
)?; )?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape( create_ndarray_dyn_shape(
generator, generator,
ctx, ctx,
@ -2262,6 +2239,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, shape_list, idx| { |generator, ctx, shape_list, idx| {
let dim = let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value(); shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback( Ok(gen_if_else_expr_callback(
generator, generator,
ctx, ctx,
@ -2271,7 +2250,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
dim, dim,
ndim_ty.const_zero(), llvm_usize.const_zero(),
"", "",
) )
.unwrap()) .unwrap())
@ -2285,16 +2264,17 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
) )
} }
BasicValueEnum::StructValue(shape_tuple) => { BasicValueEnum::StructValue(shape_tuple) => {
let ndims = shape_tuple.get_type().count_fields(); // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims { for dim_i in 0..ndims {
let dim = ctx let dim = ctx
.builder .builder
.build_extract_value(shape_tuple, dim_i, "") .build_extract_value(shape_tuple, dim_i, "")
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback( gen_if_else_expr_callback(
generator, generator,
@ -2302,7 +2282,12 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|_, ctx| { |_, ctx| {
Ok(ctx Ok(ctx
.builder .builder
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "") .build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap()) .unwrap())
}, },
|_, ctx| -> Result<Option<IntValue>, String> { |_, ctx| -> Result<Option<IntValue>, String> {
@ -2328,12 +2313,14 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize); let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims { for dim_i in 0..ndims {
let dim = ctx let dim = ctx
.builder .builder
.build_extract_value(shape_tuple, dim_i, "") .build_extract_value(shape_tuple, dim_i, "")
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback( let dim = gen_if_else_expr_callback(
generator, generator,
@ -2341,7 +2328,12 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|_, ctx| { |_, ctx| {
Ok(ctx Ok(ctx
.builder .builder
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "") .build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap()) .unwrap())
}, },
|_, _| Ok(Some(rem)), |_, _| Ok(Some(rem)),
@ -2354,6 +2346,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
} }
BasicValueEnum::IntValue(shape_int) => { BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback( let shape_int = gen_if_else_expr_callback(
generator, generator,
ctx, ctx,
@ -2363,13 +2356,15 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
shape_int, shape_int,
ndim_ty.const_zero(), llvm_usize.const_zero(),
"", "",
) )
.unwrap()) .unwrap())
}, },
|_, _| Ok(Some(n_sz)), |_, _| Ok(Some(n_sz)),
|_, _| Ok(Some(shape_int)), |_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)? )?
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
@ -2379,6 +2374,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
} }
.unwrap(); .unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert( ctx.make_assert(
generator, generator,
@ -2391,17 +2387,13 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
ctx.current_loc, ctx.current_loc,
); );
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
let out_sz = ctx
.builder
.build_cast(inkwell::values::InstructionOpcode::Trunc, out_sz, ndim_ty, "")
.unwrap()
.into_int_value();
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError", "0:ValueError",
"cannot reshape array of size {} into provided shape of size {}", "cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None], [Some(n_sz), Some(out_sz), None],
ctx.current_loc, ctx.current_loc,
); );
@ -2428,3 +2420,102 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
) )
} }
} }
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}

View File

@ -558,16 +558,17 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_1ary_function(prim) self.build_np_sp_ndarray_function(prim)
} }
PrimDef::FunNpDot PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr | PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd | PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv | PrimDef::FunNpLinalgPinv
| PrimDef::FunNpLinalgMatrixPower
| PrimDef::FunNpLinalgDet
| PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur | PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim), | PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
@ -1889,62 +1890,53 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build 1-ary numpy/scipy functions that take in an ndarray and return a value of the same type as the input. /// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
let elem_type = self.unifier.get_fresh_var(Some("R".into()), None);
let ndarray_type = make_ndarray_ty(self.unifier, self.primitives, Some(elem_type.ty), None);
let ndarray_ty =
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("T".into()), None);
let var_map = into_var_map([elem_type, ndarray_ty]);
match prim { match prim {
PrimDef::FunNpTranspose => create_fn_by_codegen( PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier, self.unifier,
&var_map, &into_var_map([ndarray_ty]),
prim.name(), prim.name(),
ndarray_ty.ty, ndarray_ty.ty,
&[(ndarray_ty.ty, "x")], &[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}), }),
),
PrimDef::FunNpReshape => {
// Return type can have differend ndims
let ret_ty =
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
let var_map = var_map
.clone()
.into_iter()
.chain(once((ret_ty.id, ret_ty.ty)))
.chain(once((
self.ndarray_factory_fn_shape_arg_tvar.id,
self.ndarray_factory_fn_shape_arg_tvar.ty,
)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ret_ty.ty,
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
) )
} }
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -1957,12 +1949,13 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::FunNpDot, PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky, PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr, PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd, PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv, PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv, PrimDef::FunNpLinalgPinv,
PrimDef::FunNpLinalgMatrixPower,
PrimDef::FunNpLinalgDet,
PrimDef::FunSpLinalgLu, PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur, PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg, PrimDef::FunSpLinalgHessenberg,
@ -1974,7 +1967,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&self.num_or_ndarray_var_map, &self.num_or_ndarray_var_map,
prim.name(), prim.name(),
self.primitives.float, self.num_ty.ty,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")], &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
@ -1982,33 +1975,7 @@ impl<'a> BuiltinBuilder<'a> {
let x2_ty = fun.0.args[1].ty; let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_dot( Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}), }),
), ),
@ -2086,10 +2053,39 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
) )
} }
_ => { PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
println!("{:?}", prim.name()); self.unifier,
unreachable!() &VarMap::new(),
} prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
}),
),
_ => unreachable!(),
} }
} }

View File

@ -104,12 +104,13 @@ pub enum PrimDef {
// Linalg functions // Linalg functions
FunNpDot, FunNpDot,
FunNpLinalgMatmul,
FunNpLinalgCholesky, FunNpLinalgCholesky,
FunNpLinalgQr, FunNpLinalgQr,
FunNpLinalgSvd, FunNpLinalgSvd,
FunNpLinalgInv, FunNpLinalgInv,
FunNpLinalgPinv, FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu, FunSpLinalgLu,
FunSpLinalgSchur, FunSpLinalgSchur,
FunSpLinalgHessenberg, FunSpLinalgHessenberg,
@ -289,12 +290,13 @@ impl PrimDef {
// Linalg functions // Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None), PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None), PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None), PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None), PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None), PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
] ]

View File

@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> {
})); }));
} }
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
} else {
arg0_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0, arg1],
keywords: vec![],
},
}));
}
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap(); let arg0_ty = arg0.custom.unwrap();
@ -1389,7 +1427,45 @@ impl<'a> Inferencer<'a> {
}, },
})); }));
} }
// 2-argument ndarray n-dimensional factory functions
if id == &"np_reshape".into() && args.len() == 2 {
let arg0 = self.fold_expr(args.remove(0))?;
let shape_expr = args.remove(0);
let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg {
name: "shape".into(),
ty: shape.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0, shape],
keywords: vec![],
},
}));
}
// 2-argument ndarray n-dimensional creation functions // 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 { if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else { let ExprKind::List { elts, .. } = &args[0].node else {

View File

@ -8,7 +8,6 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg = { path = "./demo/linalg" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -5,8 +5,8 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -231,12 +231,13 @@ def patch(module):
# Linalg functions # Linalg functions
module.np_dot = np.dot module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv module.np_linalg_pinv = np.linalg.pinv
module.np_linalg_matrix_power = np.linalg.matrix_power
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True) module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur module.sp_linalg_schur = sp.linalg.schur

114
nac3standalone/demo/linalg/Cargo.lock generated Normal file
View File

@ -0,0 +1,114 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -9,3 +9,5 @@ crate-type = ["staticlib"]
[dependencies] [dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0" cslice = "0.3.0"
[workspace]

View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
# Uses rustup to compile the linalg library for i386 and x86_84 architecture
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 cargo build -Z unstable-options --target x86_64-unknown-linux-gnu --out-dir liblinalg/x86_64"
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -Z unstable-options --target i686-unknown-linux-gnu --out-dir liblinalg/i386"

Binary file not shown.

View File

@ -34,83 +34,6 @@ impl InputMatrix {
} }
} }
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
if !(mat1.ndims == 1 && mat2.ndims == 1) {
let err_msg = format!(
"expected 1D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[0] != dim2[0] {
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
}
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if !(mat1.ndims == 2 && mat2.ndims == 2) {
let err_msg = format!(
"expected 2D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[1] != dim2[0] {
let err_msg = format!(
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
@ -120,7 +43,7 @@ pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut In
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
} }
@ -168,7 +91,7 @@ pub unsafe extern "C" fn np_linalg_qr(
let out_r = out_r.as_mut().unwrap(); let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
} }
@ -207,7 +130,7 @@ pub unsafe extern "C" fn np_linalg_svd(
let outvh = outvh.as_mut().unwrap(); let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
} }
@ -238,7 +161,7 @@ pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMa
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -277,7 +200,7 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
let out = out.as_mut().unwrap(); let out = out.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg); report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
} }
let dim1 = (*mat1).get_dims(); let dim1 = (*mat1).get_dims();
@ -299,6 +222,76 @@ pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputM
} }
} }
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let abs_pow = power.abs();
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let mut result = matrix1.pow(abs_pow as u32);
if power < 0.0 {
if !result.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety /// # Safety
/// ///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
@ -313,7 +306,7 @@ pub unsafe extern "C" fn sp_linalg_lu(
let out_u = out_u.as_mut().unwrap(); let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg); report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
} }
@ -346,7 +339,7 @@ pub unsafe extern "C" fn sp_linalg_schur(
let out_z = out_z.as_mut().unwrap(); let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg); report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
} }
@ -386,7 +379,7 @@ pub unsafe extern "C" fn sp_linalg_hessenberg(
let out_q = out_q.as_mut().unwrap(); let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 { if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims); let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg); report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
} }

View File

@ -42,14 +42,11 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
linalg=../../target/debug/deps/liblinalg-?*.a
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
linalg=../../target/release/deps/liblinalg-?*.a
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
linalg=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg-?*.a
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
@ -57,21 +54,16 @@ rm -f ./*.o ./*.bc demo
if [ -z "$i386" ]; then if [ -z "$i386" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
cd linalg && cargo build -q && cd ..
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/x86_64/liblinalg.a
else else
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations # Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}" $nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
# Compile linalg crate to provide functions compatible with i386 architecture
if [ ! -f ../../target/i686-unknown-linux-gnu/release/liblinalg.a ]; then
cd linalg && nix-shell -p rustup --command "RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -q --release --target=i686-unknown-linux-gnu" && cd ..
fi
linalg=../../target/i686-unknown-linux-gnu/release/liblinalg.a
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/i386/liblinalg.a
fi fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then

View File

@ -68,12 +68,6 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for c in range(len(n[r])): for c in range(len(n[r])):
output_float64(n[r][c]) output_float64(n[r][c])
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
for r in range(len(n)):
for c in range(len(n[r])):
for z in range(len(n[r][c])):
output_float64(n[r][c][z])
def consume_ndarray_1(n: ndarray[float, Literal[1]]): def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass pass
@ -1436,43 +1430,49 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose(): def test_ndarray_transpose():
# x: ndarray[float, 3] = np_array([[[1., 2.], [3., 4.], [5., 6.]]]) x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
x: ndarray[float, 1] = np_array([1., 2., 3.])
y = np_transpose(x) y = np_transpose(x)
z = np_transpose(y)
output_ndarray_float_1(x) output_ndarray_float_2(x)
output_ndarray_float_1(y) output_ndarray_float_2(y)
def test_ndarray_reshape(): def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1)) x = np_reshape(w, (1, 2, 1, -1))
y: ndarray[float, 2] = np_reshape(x, [2, -1]) y = np_reshape(x, [2, -1])
z: ndarray[float, 1] = np_reshape(w, 10) z = np_reshape(y, 10)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_ndarray_float_1(w) output_ndarray_float_1(w)
output_ndarray_float_2(y) output_ndarray_float_2(y)
output_ndarray_float_1(z) output_ndarray_float_1(z)
def test_ndarray_dot(): def test_ndarray_dot():
x: ndarray[float, 1] = np_array([5.0, 1.0]) x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y: ndarray[float, 1] = np_array([5.0, 1.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
z = np_dot(x, y) z1 = np_dot(x1, y1)
output_ndarray_float_1(x) x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
output_ndarray_float_1(y) y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
output_float64(z) z2 = np_dot(x2, y2)
def test_ndarray_linalg_matmul(): x3: ndarray[bool, 1] = np_array([True, True, True, True])
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) y3: ndarray[bool, 1] = np_array([True, True, True, True])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) z3 = np_dot(x3, y3)
z = np_linalg_matmul(x, y)
m = np_argmax(z) z4 = np_dot(2, 3)
z5 = np_dot(2., 3.)
z6 = np_dot(True, False)
output_ndarray_float_2(x) output_float64(z1)
output_ndarray_float_2(y) output_int32(z2)
output_ndarray_float_2(z) output_bool(z3)
output_int64(m) output_int32(z4)
output_float64(z5)
output_bool(z6)
def test_ndarray_cholesky(): def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
@ -1489,7 +1489,7 @@ def test_ndarray_qr():
# QR Factorization is not unique and gives different results in numpy and nalgebra # QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(y, z) a = y @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_linalg_inv(): def test_ndarray_linalg_inv():
@ -1506,6 +1506,20 @@ def test_ndarray_pinv():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_ndarray_float_2(y)
def test_ndarray_matrix_power():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_matrix_power(x, -9)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_det():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_det(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_schur(): def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x) t, z = sp_linalg_schur(x)
@ -1514,7 +1528,7 @@ def test_ndarray_schur():
# Schur Factorization is not unique and gives different results in scipy and nalgebra # Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_hessenberg(): def test_ndarray_hessenberg():
@ -1525,7 +1539,7 @@ def test_ndarray_hessenberg():
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra # Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q)) a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a) output_ndarray_float_2(a)
@ -1546,7 +1560,7 @@ def test_ndarray_svd():
# SVD Factorization is not unique and gives different results in numpy and nalgebra # SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(x, z) a = x @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
output_ndarray_float_1(y) output_ndarray_float_1(y)
@ -1733,12 +1747,13 @@ def run() -> int32:
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_linalg_matmul()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()
test_ndarray_svd() test_ndarray_svd()
test_ndarray_linalg_inv() test_ndarray_linalg_inv()
test_ndarray_pinv() test_ndarray_pinv()
test_ndarray_matrix_power()
test_ndarray_det()
test_ndarray_lu() test_ndarray_lu()
test_ndarray_schur() test_ndarray_schur()
test_ndarray_hessenberg() test_ndarray_hessenberg()

View File

@ -1 +0,0 @@
../target/debug/libnac3artiq.so