Compare commits
14 Commits
318a675ea6
...
669c6aca6b
Author | SHA1 | Date |
---|---|---|
Sébastien Bourdeauducq | 669c6aca6b | |
abdul124 | 63d2b49b09 | |
abdul124 | bf709889c4 | |
abdul124 | 1c72698d02 | |
abdul124 | 54f883f0a5 | |
abdul124 | 4a6845dac6 | |
abdul124 | 00236f48bc | |
abdul124 | a3e6bb2292 | |
abdul124 | 17171065b1 | |
abdul124 | 540b35ec84 | |
abdul124 | 4bb00c52e3 | |
abdul124 | faf07527cb | |
abdul124 | d6a4d0a634 | |
abdul124 | 2242c5af43 |
|
@ -1,3 +1,4 @@
|
||||||
__pycache__
|
__pycache__
|
||||||
/target
|
/target
|
||||||
|
/nac3standalone/demo/linalg/target
|
||||||
nix/windows/msys2
|
nix/windows/msys2
|
||||||
|
|
26
flake.nix
26
flake.nix
|
@ -6,6 +6,7 @@
|
||||||
outputs = { self, nixpkgs }:
|
outputs = { self, nixpkgs }:
|
||||||
let
|
let
|
||||||
pkgs = import nixpkgs { system = "x86_64-linux"; };
|
pkgs = import nixpkgs { system = "x86_64-linux"; };
|
||||||
|
pkgs32 = import nixpkgs { system = "i686-linux"; };
|
||||||
in rec {
|
in rec {
|
||||||
packages.x86_64-linux = rec {
|
packages.x86_64-linux = rec {
|
||||||
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
|
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
|
||||||
|
@ -15,6 +16,22 @@
|
||||||
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
|
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
|
||||||
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
|
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
|
||||||
'';
|
'';
|
||||||
|
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
|
||||||
|
name = "demo-linalg-stub";
|
||||||
|
src = ./nac3standalone/demo/linalg;
|
||||||
|
cargoLock = {
|
||||||
|
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
|
||||||
|
};
|
||||||
|
doCheck = false;
|
||||||
|
};
|
||||||
|
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
|
||||||
|
name = "demo-linalg-stub32";
|
||||||
|
src = ./nac3standalone/demo/linalg;
|
||||||
|
cargoLock = {
|
||||||
|
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
|
||||||
|
};
|
||||||
|
doCheck = false;
|
||||||
|
};
|
||||||
nac3artiq = pkgs.python3Packages.toPythonModule (
|
nac3artiq = pkgs.python3Packages.toPythonModule (
|
||||||
pkgs.rustPlatform.buildRustPackage rec {
|
pkgs.rustPlatform.buildRustPackage rec {
|
||||||
name = "nac3artiq";
|
name = "nac3artiq";
|
||||||
|
@ -32,7 +49,9 @@
|
||||||
echo "Checking nac3standalone demos..."
|
echo "Checking nac3standalone demos..."
|
||||||
pushd nac3standalone/demo
|
pushd nac3standalone/demo
|
||||||
patchShebangs .
|
patchShebangs .
|
||||||
./check_demos.sh
|
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
|
||||||
|
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
|
||||||
|
./check_demos.sh -i686
|
||||||
popd
|
popd
|
||||||
echo "Running Cargo tests..."
|
echo "Running Cargo tests..."
|
||||||
cargoCheckHook
|
cargoCheckHook
|
||||||
|
@ -162,6 +181,11 @@
|
||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
];
|
];
|
||||||
|
shellHook =
|
||||||
|
''
|
||||||
|
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
|
||||||
|
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
|
||||||
|
'';
|
||||||
};
|
};
|
||||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||||
name = "nac3-dev-shell-msys2";
|
name = "nac3-dev-shell-msys2";
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
use inkwell::types::BasicTypeEnum;
|
use inkwell::types::BasicTypeEnum;
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{
|
||||||
|
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
|
};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
|
@ -31,7 +33,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (n_ty, n) = n;
|
let (n_ty, n) = n;
|
||||||
|
|
||||||
Ok(match n {
|
Ok(match n {
|
||||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||||
|
@ -1836,3 +1837,501 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
|
||||||
|
fn build_output_struct<'ctx>(
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
out_matrices: Vec<BasicValueEnum<'ctx>>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let field_ty =
|
||||||
|
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
|
||||||
|
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||||
|
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||||
|
|
||||||
|
for (i, v) in out_matrices.into_iter().enumerate() {
|
||||||
|
unsafe {
|
||||||
|
let ptr = ctx
|
||||||
|
.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
out_ptr,
|
||||||
|
&[
|
||||||
|
ctx.ctx.i32_type().const_zero(),
|
||||||
|
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(ptr, v).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out_ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_cholesky` linalg function
|
||||||
|
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_cholesky";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_qr` linalg function
|
||||||
|
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_qr";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||||
|
|
||||||
|
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
||||||
|
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_svd` linalg function
|
||||||
|
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_svd";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||||
|
|
||||||
|
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
||||||
|
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_inv` linalg function
|
||||||
|
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_inv";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_pinv` linalg function
|
||||||
|
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_pinv";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `sp_linalg_lu` linalg function
|
||||||
|
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "sp_linalg_lu";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||||
|
|
||||||
|
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
|
||||||
|
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||||
|
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (x2_ty, x2) = x2;
|
||||||
|
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let n2_array = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
unsafe {
|
||||||
|
n2_array.data().set_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
n2.as_basic_value_enum(),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||||
|
|
||||||
|
let outdim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let outdim1 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||||
|
Ok(out)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_linalg_det` linalg function
|
||||||
|
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||||
|
let out = numpy::create_ndarray_const_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&[llvm_usize.const_int(1, false)],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||||
|
let res =
|
||||||
|
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
Ok(res)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `sp_linalg_schur` linalg function
|
||||||
|
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "sp_linalg_schur";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
|
||||||
|
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `sp_linalg_hessenberg` linalg function
|
||||||
|
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n1.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
|
||||||
|
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_load(out_ptr, "Hessenberg_decomposition_result")
|
||||||
|
.map(Into::into)
|
||||||
|
.unwrap())
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -130,3 +130,62 @@ pub fn call_ldexp<'ctx>(
|
||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Macro to generate `np_linalg` and `sp_linalg` functions
|
||||||
|
/// The function takes as input `NDArray` and returns ()
|
||||||
|
///
|
||||||
|
/// Arguments:
|
||||||
|
/// * `$fn_name:ident`: The identifier of the rust function to be generated
|
||||||
|
/// * `$extern_fn:literal`: Name of underlying extern function
|
||||||
|
/// * (2/3/4): Number of `NDArray` that function takes as input
|
||||||
|
///
|
||||||
|
/// Note:
|
||||||
|
/// The operands and resulting `NDArray` are both passed as input to the funcion
|
||||||
|
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
|
||||||
|
/// The function changes the content of the output `NDArray` in-place
|
||||||
|
macro_rules! generate_linalg_extern_fn {
|
||||||
|
($fn_name:ident, $extern_fn:literal, 2) => {
|
||||||
|
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
|
||||||
|
};
|
||||||
|
($fn_name:ident, $extern_fn:literal, 3) => {
|
||||||
|
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
|
||||||
|
};
|
||||||
|
($fn_name:ident, $extern_fn:literal, 4) => {
|
||||||
|
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
|
||||||
|
};
|
||||||
|
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
||||||
|
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
|
||||||
|
pub fn $fn_name<'ctx>(
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>
|
||||||
|
$(,$input_matrix: BasicValueEnum<'ctx>)*,
|
||||||
|
name: Option<&str>,
|
||||||
|
){
|
||||||
|
const FN_NAME: &str = $extern_fn;
|
||||||
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
|
||||||
|
|
||||||
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
||||||
|
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
||||||
|
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
||||||
|
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
||||||
|
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
||||||
|
|
|
@ -26,12 +26,15 @@ use crate::{
|
||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use inkwell::{
|
||||||
|
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
||||||
|
values::BasicValue,
|
||||||
|
};
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
|
@ -159,7 +162,7 @@ where
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
||||||
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
|
@ -2026,3 +2029,493 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.transpose`.
|
||||||
|
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_transpose";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// Dimensions are reversed in the transposed array
|
||||||
|
let out = create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&n1,
|
||||||
|
|_, ctx, n| Ok(n.load_ndims(ctx)),
|
||||||
|
|generator, ctx, n, idx| {
|
||||||
|
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
|
||||||
|
let new_idx = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
|
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
||||||
|
ctx.builder.build_store(rem_idx, idx).unwrap();
|
||||||
|
|
||||||
|
// Incrementally calculate the new index in the transposed array
|
||||||
|
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
||||||
|
// The formula used for indexing is:
|
||||||
|
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n1.load_ndims(ctx), false),
|
||||||
|
|generator, ctx, _, ndim| {
|
||||||
|
let ndim_rev =
|
||||||
|
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
|
||||||
|
let ndim_rev = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
let dim = unsafe {
|
||||||
|
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let rem_idx_val =
|
||||||
|
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
|
||||||
|
let new_idx_val =
|
||||||
|
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||||
|
|
||||||
|
let add_component =
|
||||||
|
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
|
||||||
|
let rem_idx_val =
|
||||||
|
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
|
||||||
|
|
||||||
|
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
|
||||||
|
let new_idx_val =
|
||||||
|
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
|
||||||
|
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||||
|
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
|
} else {
|
||||||
|
unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
||||||
|
///
|
||||||
|
/// * `x1` - `NDArray` to reshape.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
|
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||||
|
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
shape: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_reshape";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (_, shape) = shape;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
|
let out = match shape {
|
||||||
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||||
|
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
||||||
|
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||||
|
// Check for -1 in dimensions
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(shape_list.load_size(ctx, None), false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let ele =
|
||||||
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
ele,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||||
|
let num_neg_value =
|
||||||
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
let num_neg_value = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(
|
||||||
|
num_neg_value,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
let acc_value =
|
||||||
|
ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let acc_value =
|
||||||
|
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, acc_value).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
|
// Generate the output shape by filling -1 with `rem`
|
||||||
|
create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&shape_list,
|
||||||
|
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|
||||||
|
|generator, ctx, shape_list, idx| {
|
||||||
|
let dim =
|
||||||
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
Ok(gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(rem)),
|
||||||
|
|_, _| Ok(Some(dim)),
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
BasicValueEnum::StructValue(shape_tuple) => {
|
||||||
|
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
|
|
||||||
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
// Check for -1 in dims
|
||||||
|
for dim_i in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||||
|
let num_negs =
|
||||||
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
let num_negs = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, num_negs).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
|
let mut shape = Vec::with_capacity(ndims as usize);
|
||||||
|
|
||||||
|
// Reconstruct shape filling negatives with rem
|
||||||
|
for dim_i in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
let dim = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(rem)),
|
||||||
|
|_, _| Ok(Some(dim)),
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
shape.push(dim);
|
||||||
|
}
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
|
}
|
||||||
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
|
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
|
let shape_int = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
shape_int,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(n_sz)),
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Only allow one dimension to be negative
|
||||||
|
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
|
||||||
|
.unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"can only specify one unknown dimension",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// The new shape must be compatible with the old shape
|
||||||
|
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"cannot reshape array of size {0} into provided shape of size {1}",
|
||||||
|
[Some(n_sz), Some(out_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
|
} else {
|
||||||
|
unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.dot`.
|
||||||
|
/// Calculate inner product of two vectors or literals
|
||||||
|
/// For matrix multiplication use `np_matmul`
|
||||||
|
///
|
||||||
|
/// The input `NDArray` are flattened and treated as 1D
|
||||||
|
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
||||||
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (_, x2) = x2;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match (x1, x2) {
|
||||||
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||||
|
|
||||||
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"shapes ({0}), ({1}) not aligned",
|
||||||
|
[Some(n1_sz), Some(n2_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let identity =
|
||||||
|
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n1_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
|
let product = match elem1 {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(e1, elem2.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
let acc_val = match acc_val {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(e1, product.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_add(e1, product.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
Ok(acc_val)
|
||||||
|
}
|
||||||
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
_ => unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -556,6 +556,22 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpLdExp
|
| PrimDef::FunNpLdExp
|
||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||||
|
self.build_np_sp_ndarray_function(prim)
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimDef::FunNpDot
|
||||||
|
| PrimDef::FunNpLinalgCholesky
|
||||||
|
| PrimDef::FunNpLinalgQr
|
||||||
|
| PrimDef::FunNpLinalgSvd
|
||||||
|
| PrimDef::FunNpLinalgInv
|
||||||
|
| PrimDef::FunNpLinalgPinv
|
||||||
|
| PrimDef::FunNpLinalgMatrixPower
|
||||||
|
| PrimDef::FunNpLinalgDet
|
||||||
|
| PrimDef::FunSpLinalgLu
|
||||||
|
| PrimDef::FunSpLinalgSchur
|
||||||
|
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
|
||||||
};
|
};
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
|
@ -1874,6 +1890,205 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build np/sp functions that take as input `NDArray` only
|
||||||
|
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpTranspose => {
|
||||||
|
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.ndarray_num_ty],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&into_var_map([ndarray_ty]),
|
||||||
|
prim.name(),
|
||||||
|
ndarray_ty.ty,
|
||||||
|
&[(ndarray_ty.ty, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let arg_ty = fun.0.args[0].ty;
|
||||||
|
let arg_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
|
//
|
||||||
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
|
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_num_ty,
|
||||||
|
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build `np_linalg` and `sp_linalg` functions
|
||||||
|
///
|
||||||
|
/// The input to these functions must be floating point `NDArray`
|
||||||
|
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[
|
||||||
|
PrimDef::FunNpDot,
|
||||||
|
PrimDef::FunNpLinalgCholesky,
|
||||||
|
PrimDef::FunNpLinalgQr,
|
||||||
|
PrimDef::FunNpLinalgSvd,
|
||||||
|
PrimDef::FunNpLinalgInv,
|
||||||
|
PrimDef::FunNpLinalgPinv,
|
||||||
|
PrimDef::FunNpLinalgMatrixPower,
|
||||||
|
PrimDef::FunNpLinalgDet,
|
||||||
|
PrimDef::FunSpLinalgLu,
|
||||||
|
PrimDef::FunSpLinalgSchur,
|
||||||
|
PrimDef::FunSpLinalgHessenberg,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpDot => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&self.num_or_ndarray_var_map,
|
||||||
|
prim.name(),
|
||||||
|
self.num_ty.ty,
|
||||||
|
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
|
||||||
|
PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => {
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_float_2d,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
|
||||||
|
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
|
||||||
|
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimDef::FunNpLinalgQr
|
||||||
|
| PrimDef::FunSpLinalgLu
|
||||||
|
| PrimDef::FunSpLinalgSchur
|
||||||
|
| PrimDef::FunSpLinalgHessenberg => {
|
||||||
|
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||||
|
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
|
||||||
|
});
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
ret_ty,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
|
||||||
|
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
|
||||||
|
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
|
||||||
|
PrimDef::FunSpLinalgHessenberg => {
|
||||||
|
builtin_fns::call_sp_linalg_hessenberg
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimDef::FunNpLinalgSvd => {
|
||||||
|
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||||
|
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
|
||||||
|
});
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
ret_ty,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_float_2d,
|
||||||
|
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(x1_ty, x1_val),
|
||||||
|
(x2_ty, x2_val),
|
||||||
|
)?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.float,
|
||||||
|
&[(self.ndarray_float_2d, "x1")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
(prim.simple_name().into(), method_ty, prim.id())
|
(prim.simple_name().into(), method_ty, prim.id())
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,6 +99,21 @@ pub enum PrimDef {
|
||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
|
FunNpTranspose,
|
||||||
|
FunNpReshape,
|
||||||
|
|
||||||
|
// Linalg functions
|
||||||
|
FunNpDot,
|
||||||
|
FunNpLinalgCholesky,
|
||||||
|
FunNpLinalgQr,
|
||||||
|
FunNpLinalgSvd,
|
||||||
|
FunNpLinalgInv,
|
||||||
|
FunNpLinalgPinv,
|
||||||
|
FunNpLinalgMatrixPower,
|
||||||
|
FunNpLinalgDet,
|
||||||
|
FunSpLinalgLu,
|
||||||
|
FunSpLinalgSchur,
|
||||||
|
FunSpLinalgHessenberg,
|
||||||
|
|
||||||
// Miscellaneous Python & NAC3 functions
|
// Miscellaneous Python & NAC3 functions
|
||||||
FunInt32,
|
FunInt32,
|
||||||
|
@ -270,6 +285,21 @@ impl PrimDef {
|
||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
|
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||||
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
|
||||||
|
// Linalg functions
|
||||||
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
|
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||||
|
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
||||||
|
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||||
|
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
||||||
|
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
||||||
|
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
|
||||||
|
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
|
||||||
|
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
||||||
|
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
||||||
|
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
||||||
|
|
||||||
// Miscellaneous Python & NAC3 functions
|
// Miscellaneous Python & NAC3 functions
|
||||||
PrimDef::FunInt32 => fun("int32", None),
|
PrimDef::FunInt32 => fun("int32", None),
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> {
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if id == &"np_dot".into() {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
|
||||||
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|
{
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
|
|
||||||
|
ndarray_dtype
|
||||||
|
} else {
|
||||||
|
arg0_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
|
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, arg1],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let arg0_ty = arg0.custom.unwrap();
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
@ -1389,7 +1427,45 @@ impl<'a> Inferencer<'a> {
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
// 2-argument ndarray n-dimensional factory functions
|
||||||
|
if id == &"np_reshape".into() && args.len() == 2 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let shape_expr = args.remove(0);
|
||||||
|
let (ndims, shape) =
|
||||||
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||||
|
|
||||||
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
|
||||||
|
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
|
FuncArg {
|
||||||
|
name: "shape".into(),
|
||||||
|
ty: shape.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, shape],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
// 2-argument ndarray n-dimensional creation functions
|
// 2-argument ndarray n-dimensional creation functions
|
||||||
if id == &"np_full".into() && args.len() == 2 {
|
if id == &"np_full".into() && args.len() == 2 {
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
|
|
@ -3,26 +3,49 @@
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
if [ -z "$1" ]; then
|
if [ -z "$1" ]; then
|
||||||
echo "Requires at least one argument"
|
echo "No argument supplied"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
declare -a nac3args
|
declare -a nac3args
|
||||||
|
while [ $# -ge 2 ]; do
|
||||||
|
case "$1" in
|
||||||
|
--help)
|
||||||
|
echo "Usage: check_demo.sh [-i686] -- demo [NAC3ARGS...]"
|
||||||
|
exit
|
||||||
|
;;
|
||||||
|
-i686)
|
||||||
|
i686=1
|
||||||
|
;;
|
||||||
|
--)
|
||||||
|
shift
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
shift
|
||||||
|
done
|
||||||
|
|
||||||
|
demo="$1"
|
||||||
|
shift
|
||||||
while [ $# -gt 1 ]; do
|
while [ $# -gt 1 ]; do
|
||||||
nac3args+=("$1")
|
nac3args+=("$1")
|
||||||
shift
|
shift
|
||||||
done
|
done
|
||||||
demo="$1"
|
|
||||||
|
|
||||||
echo "### Checking $demo..."
|
echo "### Checking $demo..."
|
||||||
|
|
||||||
# Get reference output
|
|
||||||
echo ">>>>>> Running $demo with the Python interpreter"
|
echo ">>>>>> Running $demo with the Python interpreter"
|
||||||
./interpret_demo.py "$demo" > interpreted.log
|
./interpret_demo.py "$demo" > interpreted.log
|
||||||
|
|
||||||
|
if [ -n "$i686" ]; then
|
||||||
echo "...... Trying NAC3's 32-bit code generator output"
|
echo "...... Trying NAC3's 32-bit code generator output"
|
||||||
./run_demo.sh -i386 --out run_32.log "${nac3args[@]}" "$demo"
|
./run_demo.sh -i686 --out run_32.log "${nac3args[@]}" "$demo"
|
||||||
diff -Nau interpreted.log run_32.log
|
diff -Nau interpreted.log run_32.log
|
||||||
|
fi
|
||||||
|
|
||||||
echo "...... Trying NAC3's 64-bit code generator output"
|
echo "...... Trying NAC3's 64-bit code generator output"
|
||||||
./run_demo.sh --out run_64.log "${nac3args[@]}" "$demo"
|
./run_demo.sh --out run_64.log "${nac3args[@]}" "$demo"
|
||||||
|
|
|
@ -6,6 +6,7 @@ import importlib.machinery
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
import scipy as sp
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from numpy import int32, int64, uint32, uint64
|
from numpy import int32, int64, uint32, uint64
|
||||||
|
@ -217,6 +218,8 @@ def patch(module):
|
||||||
module.np_ldexp = np.ldexp
|
module.np_ldexp = np.ldexp
|
||||||
module.np_hypot = np.hypot
|
module.np_hypot = np.hypot
|
||||||
module.np_nextafter = np.nextafter
|
module.np_nextafter = np.nextafter
|
||||||
|
module.np_transpose = np.transpose
|
||||||
|
module.np_reshape = np.reshape
|
||||||
|
|
||||||
# SciPy Math functions
|
# SciPy Math functions
|
||||||
module.sp_spec_erf = special.erf
|
module.sp_spec_erf = special.erf
|
||||||
|
@ -226,6 +229,20 @@ def patch(module):
|
||||||
module.sp_spec_j0 = special.j0
|
module.sp_spec_j0 = special.j0
|
||||||
module.sp_spec_j1 = special.j1
|
module.sp_spec_j1 = special.j1
|
||||||
|
|
||||||
|
# Linalg functions
|
||||||
|
module.np_dot = np.dot
|
||||||
|
module.np_linalg_cholesky = np.linalg.cholesky
|
||||||
|
module.np_linalg_qr = np.linalg.qr
|
||||||
|
module.np_linalg_svd = np.linalg.svd
|
||||||
|
module.np_linalg_inv = np.linalg.inv
|
||||||
|
module.np_linalg_pinv = np.linalg.pinv
|
||||||
|
module.np_linalg_matrix_power = np.linalg.matrix_power
|
||||||
|
module.np_linalg_det = np.linalg.det
|
||||||
|
|
||||||
|
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
||||||
|
module.sp_linalg_schur = sp.linalg.schur
|
||||||
|
module.sp_linalg_hessenberg = lambda x: sp.linalg.hessenberg(x, True)
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
modname = prefix + filename.stem
|
modname = prefix + filename.stem
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "approx"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cslice"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libm"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linalg"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"cslice",
|
||||||
|
"nalgebra",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nalgebra"
|
||||||
|
version = "0.32.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
"simba",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.46"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||||
|
dependencies = [
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"libm",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simba"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-traits",
|
||||||
|
"paste",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.17.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
|
@ -0,0 +1,13 @@
|
||||||
|
[package]
|
||||||
|
name = "linalg"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
crate-type = ["staticlib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
||||||
|
cslice = "0.3.0"
|
||||||
|
|
||||||
|
[workspace]
|
|
@ -0,0 +1,406 @@
|
||||||
|
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
|
||||||
|
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
|
||||||
|
//
|
||||||
|
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
|
||||||
|
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
|
||||||
|
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
|
||||||
|
|
||||||
|
use core::slice;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
|
fn report_error(
|
||||||
|
error_name: &str,
|
||||||
|
fn_name: &str,
|
||||||
|
file_name: &str,
|
||||||
|
line_num: u32,
|
||||||
|
col_num: u32,
|
||||||
|
err_msg: &str,
|
||||||
|
) -> ! {
|
||||||
|
panic!(
|
||||||
|
"Exception {} from {} in {}:{}:{}, message: {}",
|
||||||
|
error_name, fn_name, file_name, line_num, col_num, err_msg
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct InputMatrix {
|
||||||
|
pub ndims: usize,
|
||||||
|
pub dims: *const usize,
|
||||||
|
pub data: *mut f64,
|
||||||
|
}
|
||||||
|
impl InputMatrix {
|
||||||
|
fn get_dims(&mut self) -> Vec<usize> {
|
||||||
|
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
|
||||||
|
dims.to_vec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
if dim1[0] != dim1[1] {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let result = matrix1.cholesky();
|
||||||
|
match result {
|
||||||
|
Some(res) => {
|
||||||
|
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
report_error(
|
||||||
|
"LinAlgError",
|
||||||
|
"np_linalg_cholesky",
|
||||||
|
file!(),
|
||||||
|
line!(),
|
||||||
|
column!(),
|
||||||
|
"Matrix is not positive definite",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_qr(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
out_q: *mut InputMatrix,
|
||||||
|
out_r: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out_q = out_q.as_mut().unwrap();
|
||||||
|
let out_r = out_r.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let outq_dim = (*out_q).get_dims();
|
||||||
|
let outr_dim = (*out_r).get_dims();
|
||||||
|
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
|
||||||
|
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
|
||||||
|
|
||||||
|
// Refer to https://github.com/dimforge/nalgebra/issues/735
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
|
||||||
|
let res = matrix1.qr();
|
||||||
|
let (q, r) = res.unpack();
|
||||||
|
|
||||||
|
// Uses different algo need to match numpy
|
||||||
|
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
||||||
|
out_r_slice.copy_from_slice(r.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_svd(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
outu: *mut InputMatrix,
|
||||||
|
outs: *mut InputMatrix,
|
||||||
|
outvh: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let outu = outu.as_mut().unwrap();
|
||||||
|
let outs = outs.as_mut().unwrap();
|
||||||
|
let outvh = outvh.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let outu_dim = (*outu).get_dims();
|
||||||
|
let outs_dim = (*outs).get_dims();
|
||||||
|
let outvh_dim = (*outvh).get_dims();
|
||||||
|
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
|
||||||
|
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
|
||||||
|
let out_vh_slice =
|
||||||
|
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let result = matrix.svd(true, true);
|
||||||
|
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
|
||||||
|
out_s_slice.copy_from_slice(result.singular_values.as_slice());
|
||||||
|
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
|
||||||
|
if dim1[0] != dim1[1] {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix.is_invertible() {
|
||||||
|
report_error(
|
||||||
|
"LinAlgError",
|
||||||
|
"np_linalg_inv",
|
||||||
|
file!(),
|
||||||
|
line!(),
|
||||||
|
column!(),
|
||||||
|
"no inverse for Singular Matrix",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let inv = matrix.try_inverse().unwrap();
|
||||||
|
out_slice.copy_from_slice(inv.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let svd = matrix.svd(true, true);
|
||||||
|
let inv = svd.pseudo_inverse(1e-15);
|
||||||
|
|
||||||
|
match inv {
|
||||||
|
Ok(m) => {
|
||||||
|
out_slice.copy_from_slice(m.transpose().as_slice());
|
||||||
|
}
|
||||||
|
Err(err_msg) => {
|
||||||
|
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_matrix_power(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
mat2: *mut InputMatrix,
|
||||||
|
out: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let mat2 = mat2.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
|
||||||
|
let power = power[0];
|
||||||
|
let outdim = out.get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let abs_pow = power.abs();
|
||||||
|
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let mut result = matrix1.pow(abs_pow as u32);
|
||||||
|
|
||||||
|
if power < 0.0 {
|
||||||
|
if !result.is_invertible() {
|
||||||
|
report_error(
|
||||||
|
"LinAlgError",
|
||||||
|
"np_linalg_inv",
|
||||||
|
file!(),
|
||||||
|
line!(),
|
||||||
|
column!(),
|
||||||
|
"no inverse for Singular Matrix",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result = result.try_inverse().unwrap();
|
||||||
|
}
|
||||||
|
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out = out.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
if !matrix.is_square() {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
out_slice[0] = matrix.determinant();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn sp_linalg_lu(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
out_l: *mut InputMatrix,
|
||||||
|
out_u: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out_l = out_l.as_mut().unwrap();
|
||||||
|
let out_u = out_u.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
let outl_dim = (*out_l).get_dims();
|
||||||
|
let outu_dim = (*out_u).get_dims();
|
||||||
|
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
|
||||||
|
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let (_, l, u) = matrix.lu().unpack();
|
||||||
|
|
||||||
|
out_l_slice.copy_from_slice(l.transpose().as_slice());
|
||||||
|
out_u_slice.copy_from_slice(u.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn sp_linalg_schur(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
out_t: *mut InputMatrix,
|
||||||
|
out_z: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out_t = out_t.as_mut().unwrap();
|
||||||
|
let out_z = out_z.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
|
||||||
|
if dim1[0] != dim1[1] {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let out_t_dim = (*out_t).get_dims();
|
||||||
|
let out_z_dim = (*out_z).get_dims();
|
||||||
|
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
|
||||||
|
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let (z, t) = matrix.schur().unpack();
|
||||||
|
|
||||||
|
out_t_slice.copy_from_slice(t.transpose().as_slice());
|
||||||
|
out_z_slice.copy_from_slice(z.transpose().as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn sp_linalg_hessenberg(
|
||||||
|
mat1: *mut InputMatrix,
|
||||||
|
out_h: *mut InputMatrix,
|
||||||
|
out_q: *mut InputMatrix,
|
||||||
|
) {
|
||||||
|
let mat1 = mat1.as_mut().unwrap();
|
||||||
|
let out_h = out_h.as_mut().unwrap();
|
||||||
|
let out_q = out_q.as_mut().unwrap();
|
||||||
|
|
||||||
|
if mat1.ndims != 2 {
|
||||||
|
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
|
||||||
|
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim1 = (*mat1).get_dims();
|
||||||
|
|
||||||
|
if dim1[0] != dim1[1] {
|
||||||
|
let err_msg =
|
||||||
|
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
|
||||||
|
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let out_h_dim = (*out_h).get_dims();
|
||||||
|
let out_q_dim = (*out_q).get_dims();
|
||||||
|
|
||||||
|
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||||
|
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
|
||||||
|
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
|
||||||
|
|
||||||
|
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||||
|
let (q, h) = matrix.hessenberg().unpack();
|
||||||
|
|
||||||
|
out_h_slice.copy_from_slice(h.transpose().as_slice());
|
||||||
|
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
||||||
|
}
|
|
@ -2,6 +2,9 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
: "${DEMO_LINALG_STUB:=linalg/target/release/liblinalg.a}"
|
||||||
|
: "${DEMO_LINALG_STUB32:=linalg/target/i686-unknown-linux-gnu/release/liblinalg.a}"
|
||||||
|
|
||||||
if [ -z "$1" ]; then
|
if [ -z "$1" ]; then
|
||||||
echo "No argument supplied"
|
echo "No argument supplied"
|
||||||
exit 1
|
exit 1
|
||||||
|
@ -11,7 +14,7 @@ declare -a nac3args
|
||||||
while [ $# -ge 1 ]; do
|
while [ $# -ge 1 ]; do
|
||||||
case "$1" in
|
case "$1" in
|
||||||
--help)
|
--help)
|
||||||
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i386] -- [NAC3ARGS...]"
|
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i686] -- [NAC3ARGS...]"
|
||||||
exit
|
exit
|
||||||
;;
|
;;
|
||||||
--out)
|
--out)
|
||||||
|
@ -21,8 +24,8 @@ while [ $# -ge 1 ]; do
|
||||||
--debug)
|
--debug)
|
||||||
debug=1
|
debug=1
|
||||||
;;
|
;;
|
||||||
-i386)
|
-i686)
|
||||||
i386=1
|
i686=1
|
||||||
;;
|
;;
|
||||||
--)
|
--)
|
||||||
shift
|
shift
|
||||||
|
@ -51,18 +54,14 @@ fi
|
||||||
|
|
||||||
rm -f ./*.o ./*.bc demo
|
rm -f ./*.o ./*.bc demo
|
||||||
|
|
||||||
if [ -z "$i386" ]; then
|
if [ -z "$i686" ]; then
|
||||||
$nac3standalone "${nac3args[@]}"
|
$nac3standalone "${nac3args[@]}"
|
||||||
|
|
||||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||||
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o
|
clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch
|
||||||
else
|
else
|
||||||
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
|
$nac3standalone --triple i686-unknown-linux-gnu "${nac3args[@]}"
|
||||||
|
|
||||||
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
|
|
||||||
|
|
||||||
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
|
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
|
||||||
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o
|
clang -m32 -o demo module.o demo.o $DEMO_LINALG_STUB32 -lm -Wl,--no-warn-search-mismatch
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -z "$outfile" ]; then
|
if [ -z "$outfile" ]; then
|
||||||
|
|
|
@ -1429,6 +1429,142 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_transpose():
|
||||||
|
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
|
||||||
|
y = np_transpose(x)
|
||||||
|
z = np_transpose(y)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_reshape():
|
||||||
|
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
|
||||||
|
x = np_reshape(w, (1, 2, 1, -1))
|
||||||
|
y = np_reshape(x, [2, -1])
|
||||||
|
z = np_reshape(y, 10)
|
||||||
|
|
||||||
|
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
|
||||||
|
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
|
||||||
|
|
||||||
|
output_ndarray_float_1(w)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
output_ndarray_float_1(z)
|
||||||
|
|
||||||
|
def test_ndarray_dot():
|
||||||
|
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||||
|
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||||
|
z1 = np_dot(x1, y1)
|
||||||
|
|
||||||
|
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
|
||||||
|
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
|
||||||
|
z2 = np_dot(x2, y2)
|
||||||
|
|
||||||
|
x3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
|
y3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
|
z3 = np_dot(x3, y3)
|
||||||
|
|
||||||
|
z4 = np_dot(2, 3)
|
||||||
|
z5 = np_dot(2., 3.)
|
||||||
|
z6 = np_dot(True, False)
|
||||||
|
|
||||||
|
output_float64(z1)
|
||||||
|
output_int32(z2)
|
||||||
|
output_bool(z3)
|
||||||
|
output_int32(z4)
|
||||||
|
output_float64(z5)
|
||||||
|
output_bool(z6)
|
||||||
|
|
||||||
|
def test_ndarray_cholesky():
|
||||||
|
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||||
|
y = np_linalg_cholesky(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_qr():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y, z = np_linalg_qr(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
# QR Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
|
# Reverting the decomposition to compare the initial arrays
|
||||||
|
a = y @ z
|
||||||
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
|
def test_ndarray_linalg_inv():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_inv(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_pinv():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
||||||
|
y = np_linalg_pinv(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_matrix_power():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_matrix_power(x, -9)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_det():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
y = np_linalg_det(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
|
def test_ndarray_schur():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
t, z = sp_linalg_schur(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
# Schur Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
|
# Reverting the decomposition to compare the initial arrays
|
||||||
|
a = (z @ t) @ np_linalg_inv(z)
|
||||||
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
|
def test_ndarray_hessenberg():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 5.0, 8.5]])
|
||||||
|
h, q = sp_linalg_hessenberg(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
|
||||||
|
# Reverting the decomposition to compare the initial arrays
|
||||||
|
a = (q @ h) @ np_linalg_inv(q)
|
||||||
|
output_ndarray_float_2(a)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_lu():
|
||||||
|
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
||||||
|
l, u = sp_linalg_lu(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(l)
|
||||||
|
output_ndarray_float_2(u)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_svd():
|
||||||
|
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||||
|
x, y, z = np_linalg_svd(w)
|
||||||
|
|
||||||
|
output_ndarray_float_2(w)
|
||||||
|
|
||||||
|
# SVD Factorization is not unique and gives different results in numpy and nalgebra
|
||||||
|
# Reverting the decomposition to compare the initial arrays
|
||||||
|
a = x @ z
|
||||||
|
output_ndarray_float_2(a)
|
||||||
|
output_ndarray_float_1(y)
|
||||||
|
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
@ -1607,5 +1743,18 @@ def run() -> int32:
|
||||||
test_ndarray_nextafter_broadcast()
|
test_ndarray_nextafter_broadcast()
|
||||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_transpose()
|
||||||
|
test_ndarray_reshape()
|
||||||
|
|
||||||
|
test_ndarray_dot()
|
||||||
|
test_ndarray_cholesky()
|
||||||
|
test_ndarray_qr()
|
||||||
|
test_ndarray_svd()
|
||||||
|
test_ndarray_linalg_inv()
|
||||||
|
test_ndarray_pinv()
|
||||||
|
test_ndarray_matrix_power()
|
||||||
|
test_ndarray_det()
|
||||||
|
test_ndarray_lu()
|
||||||
|
test_ndarray_schur()
|
||||||
|
test_ndarray_hessenberg()
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue