[core] codegen: Implement matrix_power

Last of the functions that need to be ported over to strided-ndarray.
This commit is contained in:
David Mak 2024-12-12 11:19:12 +08:00
parent 27a6f47330
commit a00eb7969e
2 changed files with 42 additions and 51 deletions

View File

@ -1,6 +1,6 @@
use inkwell::{ use inkwell::{
types::BasicTypeEnum, types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
FloatPredicate, IntPredicate, OptimizationLevel, FloatPredicate, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -17,7 +17,7 @@ use super::{
types::ndarray::NDArrayType, types::ndarray::NDArrayType,
values::{ values::{
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeAccessor,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -2165,58 +2165,49 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { let BasicValueEnum::PointerValue(x1) = x1 else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = llvm_ndarray_ty.map_value(n1, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
} }
// x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call.
let x2 = call_float(generator, ctx, (x2_ty, x2))?;
let BasicValueEnum::FloatValue(x2) = x2 else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
};
let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into())
.construct_unsized(generator, ctx, &x2, None); // x2.shape == []
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_uninitialized(generator, ctx, None);
out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let x2_c = x2.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_matrix_power(
ctx,
x1_c.as_base_value().into(),
x2_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
} }
/// Invokes the `np_linalg_det` linalg function /// Invokes the `np_linalg_det` linalg function

View File

@ -1764,7 +1764,7 @@ def run() -> int32:
test_ndarray_svd() test_ndarray_svd()
test_ndarray_linalg_inv() test_ndarray_linalg_inv()
test_ndarray_pinv() test_ndarray_pinv()
# test_ndarray_matrix_power() test_ndarray_matrix_power()
test_ndarray_det() test_ndarray_det()
test_ndarray_lu() test_ndarray_lu()
test_ndarray_schur() test_ndarray_schur()