forked from M-Labs/nac3
[core] codegen: Implement matrix_power
Last of the functions that need to be ported over to strided-ndarray.
This commit is contained in:
parent
27a6f47330
commit
a00eb7969e
@ -1,6 +1,6 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
FloatPredicate, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
@ -17,7 +17,7 @@ use super::{
|
||||
types::ndarray::NDArrayType,
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
@ -2165,58 +2165,49 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
|
||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let n2_array = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
n2_array.data().set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
n2.as_basic_value_enum(),
|
||||
);
|
||||
};
|
||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.shape()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n1.shape()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||
|
||||
Ok(out)
|
||||
} else {
|
||||
let BasicValueEnum::PointerValue(x1) = x1 else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
};
|
||||
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||
|
||||
if !x1.get_type().element_type().is_float_type() {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
}
|
||||
|
||||
// x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call.
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2))?;
|
||||
let BasicValueEnum::FloatValue(x2) = x2 else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
};
|
||||
|
||||
let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into())
|
||||
.construct_unsized(generator, ctx, &x2, None); // x2.shape == []
|
||||
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
|
||||
|
||||
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
unsafe { out.create_data(generator, ctx) };
|
||||
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||
let x2_c = x2.make_contiguous_ndarray(generator, ctx);
|
||||
let out_c = out.make_contiguous_ndarray(generator, ctx);
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(
|
||||
ctx,
|
||||
x1_c.as_base_value().into(),
|
||||
x2_c.as_base_value().into(),
|
||||
out_c.as_base_value().into(),
|
||||
None,
|
||||
);
|
||||
|
||||
Ok(out.as_base_value().into())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_det` linalg function
|
||||
|
@ -1764,7 +1764,7 @@ def run() -> int32:
|
||||
test_ndarray_svd()
|
||||
test_ndarray_linalg_inv()
|
||||
test_ndarray_pinv()
|
||||
# test_ndarray_matrix_power()
|
||||
test_ndarray_matrix_power()
|
||||
test_ndarray_det()
|
||||
test_ndarray_lu()
|
||||
test_ndarray_schur()
|
||||
|
Loading…
Reference in New Issue
Block a user