Compare commits

...

11 Commits

21 changed files with 1414 additions and 197 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

View File

@ -3,7 +3,9 @@ use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{
NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -1865,84 +1867,7 @@ fn build_output_struct<'ctx>(
out_ptr out_ptr
} }
/// Invokes the `np_dot` using `nalgebra` crate /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_matmul` using `nalgebra` crate
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n2.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_cholesky` using `nalgebra` crate
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1957,7 +1882,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -1984,7 +1909,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_qr` using `nalgebra` crate /// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2034,7 +1959,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_svd` using `nalgebra` crate /// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2049,7 +1974,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2089,7 +2014,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_inv` using `nalgebra` crate /// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2104,7 +2029,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2131,7 +2056,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_pinv` using `nalgebra` crate /// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2146,7 +2071,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2174,7 +2099,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_lu` using `nalgebra` crate /// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2189,7 +2114,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2224,7 +2149,105 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_schur` using `nalgebra` crate /// Invokes the `np_linalg_matrix_power` linalg function
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_det` linalg function
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2239,7 +2262,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2267,7 +2290,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_hessenberg` using `nalgebra` crate /// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2282,7 +2305,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);

View File

@ -179,42 +179,13 @@ macro_rules! generate_linalg_extern_fn {
}; };
} }
generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3);
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2); generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2); generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3); generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3); generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3); generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
/// Invokes the linalg `np_dot` function.
pub fn call_np_dot<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
mat1: BasicValueEnum<'ctx>,
mat2: BasicValueEnum<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "np_dot";
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type =
ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -26,12 +26,15 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
@ -2026,3 +2029,493 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.transpose`.
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
// Incrementally calculate the new index in the transposed array
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
// The formula used for indexing is:
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
///
/// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
// Generate the output shape by filling -1 with `rem`
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
let ndims = shape_tuple.get_type().count_fields();
// Check for -1 in dims
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
// Reconstruct shape filling negatives with rem
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
llvm_usize.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, ctx| {
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
},
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
.unwrap();
// Only allow one dimension to be negative
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
// The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {0} into provided shape of size {1}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}

View File

@ -557,13 +557,18 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_function(prim)
}
PrimDef::FunNpDot PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr | PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd | PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv | PrimDef::FunNpLinalgPinv
| PrimDef::FunNpLinalgMatrixPower
| PrimDef::FunNpLinalgDet
| PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur | PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim), | PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim),
@ -1885,6 +1890,57 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build np/sp functions that take as input `NDArray` only
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
match prim {
PrimDef::FunNpTranspose => {
let ndarray_ty = self.unifier.get_fresh_var_with_range(
&[self.ndarray_num_ty],
Some("T".into()),
None,
);
create_fn_by_codegen(
self.unifier,
&into_var_map([ndarray_ty]),
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
)
}
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_num_ty,
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions /// Build `np_linalg` and `sp_linalg` functions
/// ///
/// The input to these functions must be floating point `NDArray` /// The input to these functions must be floating point `NDArray`
@ -1893,12 +1949,13 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::FunNpDot, PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky, PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr, PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd, PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv, PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv, PrimDef::FunNpLinalgPinv,
PrimDef::FunNpLinalgMatrixPower,
PrimDef::FunNpLinalgDet,
PrimDef::FunSpLinalgLu, PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur, PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg, PrimDef::FunSpLinalgHessenberg,
@ -1910,7 +1967,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&self.num_or_ndarray_var_map, &self.num_or_ndarray_var_map,
prim.name(), prim.name(),
self.primitives.float, self.num_ty.ty,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")], &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
@ -1918,33 +1975,7 @@ impl<'a> BuiltinBuilder<'a> {
let x2_ty = fun.0.args[1].ty; let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_dot( Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}), }),
), ),
@ -2022,10 +2053,39 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
) )
} }
_ => { PrimDef::FunNpLinalgMatrixPower => create_fn_by_codegen(
println!("{:?}", prim.name()); self.unifier,
unreachable!() &VarMap::new(),
} prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.primitives.int32, "power")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matrix_power(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgDet => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_det(generator, ctx, (x1_ty, x1_val))?))
}),
),
_ => unreachable!(),
} }
} }

View File

@ -99,14 +99,18 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions
FunNpDot, FunNpDot,
FunNpLinalgMatmul,
FunNpLinalgCholesky, FunNpLinalgCholesky,
FunNpLinalgQr, FunNpLinalgQr,
FunNpLinalgSvd, FunNpLinalgSvd,
FunNpLinalgInv, FunNpLinalgInv,
FunNpLinalgPinv, FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu, FunSpLinalgLu,
FunSpLinalgSchur, FunSpLinalgSchur,
FunSpLinalgHessenberg, FunSpLinalgHessenberg,
@ -281,13 +285,18 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None), PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None), PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None), PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None), PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None), PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
] ]

View File

@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> {
})); }));
} }
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
} else {
arg0_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0, arg1],
keywords: vec![],
},
}));
}
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap(); let arg0_ty = arg0.custom.unwrap();
@ -1389,7 +1427,45 @@ impl<'a> Inferencer<'a> {
}, },
})); }));
} }
// 2-argument ndarray n-dimensional factory functions
if id == &"np_reshape".into() && args.len() == 2 {
let arg0 = self.fold_expr(args.remove(0))?;
let shape_expr = args.remove(0);
let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg {
name: "shape".into(),
ty: shape.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0, shape],
keywords: vec![],
},
}));
}
// 2-argument ndarray n-dimensional creation functions // 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 { if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else { let ExprKind::List { elts, .. } = &args[0].node else {

View File

@ -5,8 +5,8 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
@ -218,6 +218,8 @@ def patch(module):
module.np_ldexp = np.ldexp module.np_ldexp = np.ldexp
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions # SciPy Math functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
@ -229,12 +231,13 @@ def patch(module):
# Linalg functions # Linalg functions
module.np_dot = np.dot module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv module.np_linalg_pinv = np.linalg.pinv
module.np_linalg_matrix_power = np.linalg.matrix_power
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True) module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur module.sp_linalg_schur = sp.linalg.schur

114
nac3standalone/demo/linalg/Cargo.lock generated Normal file
View File

@ -0,0 +1,114 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -0,0 +1,13 @@
[package]
name = "linalg"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"
[workspace]

View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
# Uses rustup to compile the linalg library for i386 and x86_84 architecture
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 cargo build -Z unstable-options --target x86_64-unknown-linux-gnu --out-dir liblinalg/x86_64"
nix-shell -p rustup --command "RUSTC_BOOTSTRAP=1 RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -Z unstable-options --target i686-unknown-linux-gnu --out-dir liblinalg/i386"

Binary file not shown.

View File

@ -0,0 +1,406 @@
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
//
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
use core::slice;
use nalgebra::DMatrix;
fn report_error(
error_name: &str,
fn_name: &str,
file_name: &str,
line_num: u32,
col_num: u32,
err_msg: &str,
) -> ! {
panic!(
"Exception {} from {} in {}:{}:{}, message: {}",
error_name, fn_name, file_name, line_num, col_num, err_msg
);
}
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
report_error(
"LinAlgError",
"np_linalg_cholesky",
file!(),
line!(),
column!(),
"Matrix is not positive definite",
);
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
out_q: *mut InputMatrix,
out_r: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*out_q).get_dims();
let outr_dim = (*out_r).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
}
Err(err_msg) => {
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
}
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let abs_pow = power.abs();
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let mut result = matrix1.pow(abs_pow as u32);
if power < 0.0 {
if !result.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
mat1: *mut InputMatrix,
out_l: *mut InputMatrix,
out_u: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_l = out_l.as_mut().unwrap();
let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outl_dim = (*out_l).get_dims();
let outu_dim = (*out_u).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
mat1: *mut InputMatrix,
out_t: *mut InputMatrix,
out_z: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_t = out_t.as_mut().unwrap();
let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let out_t_dim = (*out_t).get_dims();
let out_z_dim = (*out_z).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
mat1: *mut InputMatrix,
out_h: *mut InputMatrix,
out_q: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_h = out_h.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let out_h_dim = (*out_h).get_dims();
let out_q_dim = (*out_q).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (q, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}

View File

@ -54,15 +54,16 @@ rm -f ./*.o ./*.bc demo
if [ -z "$i386" ]; then if [ -z "$i386" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
cd linalg && cargo build -q && cd ..
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/x86_64/liblinalg.a
else else
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations # Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}" $nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o linalg/liblinalg/i386/liblinalg.a
fi fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then

View File

@ -1429,26 +1429,50 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_dot(): def test_ndarray_transpose():
x: ndarray[float, 1] = np_array([5.0, 1.0]) x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y: ndarray[float, 1] = np_array([5.0, 1.0]) y = np_transpose(x)
z = np_dot(x, y) z = np_transpose(y)
output_ndarray_float_1(x)
output_ndarray_float_1(y)
output_float64(z)
def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m) def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
z1 = np_dot(x1, y1)
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
z2 = np_dot(x2, y2)
x3: ndarray[bool, 1] = np_array([True, True, True, True])
y3: ndarray[bool, 1] = np_array([True, True, True, True])
z3 = np_dot(x3, y3)
z4 = np_dot(2, 3)
z5 = np_dot(2., 3.)
z6 = np_dot(True, False)
output_float64(z1)
output_int32(z2)
output_bool(z3)
output_int32(z4)
output_float64(z5)
output_bool(z6)
def test_ndarray_cholesky(): def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
@ -1465,7 +1489,7 @@ def test_ndarray_qr():
# QR Factorization is not unique and gives different results in numpy and nalgebra # QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(y, z) a = y @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_linalg_inv(): def test_ndarray_linalg_inv():
@ -1482,6 +1506,20 @@ def test_ndarray_pinv():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_ndarray_float_2(y)
def test_ndarray_matrix_power():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_matrix_power(x, -9)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_det():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_det(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_schur(): def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x) t, z = sp_linalg_schur(x)
@ -1490,7 +1528,7 @@ def test_ndarray_schur():
# Schur Factorization is not unique and gives different results in scipy and nalgebra # Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a) output_ndarray_float_2(a)
def test_ndarray_hessenberg(): def test_ndarray_hessenberg():
@ -1501,7 +1539,7 @@ def test_ndarray_hessenberg():
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra # Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q)) a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a) output_ndarray_float_2(a)
@ -1522,7 +1560,7 @@ def test_ndarray_svd():
# SVD Factorization is not unique and gives different results in numpy and nalgebra # SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays # Reverting the decomposition to compare the initial arrays
a = np_linalg_matmul(x, z) a = x @ z
output_ndarray_float_2(a) output_ndarray_float_2(a)
output_ndarray_float_1(y) output_ndarray_float_1(y)
@ -1705,14 +1743,17 @@ def run() -> int32:
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_linalg_matmul()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()
test_ndarray_svd() test_ndarray_svd()
test_ndarray_linalg_inv() test_ndarray_linalg_inv()
test_ndarray_pinv() test_ndarray_pinv()
test_ndarray_matrix_power()
test_ndarray_det()
test_ndarray_lu() test_ndarray_lu()
test_ndarray_schur() test_ndarray_schur()
test_ndarray_hessenberg() test_ndarray_hessenberg()