forked from M-Labs/nac3
[core] codegen: Implement matrix_power
This is the last of the functions that needed to be ported over to strided-ndarray.
This commit is contained in:
parent
27a6f47330
commit
a00eb7969e
@ -1,6 +1,6 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicTypeEnum,
|
types::BasicTypeEnum,
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
FloatPredicate, IntPredicate, OptimizationLevel,
|
FloatPredicate, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
@ -17,7 +17,7 @@ use super::{
|
|||||||
types::ndarray::NDArrayType,
|
types::ndarray::NDArrayType,
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
@ -2165,58 +2165,49 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||||
|
|
||||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else {
|
let BasicValueEnum::PointerValue(x1) = x1 else {
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
|
||||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
|
||||||
let n2_array = numpy::create_ndarray_const_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&[llvm_usize.const_int(1, false)],
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
unsafe {
|
|
||||||
n2_array.data().set_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
n2.as_basic_value_enum(),
|
|
||||||
);
|
|
||||||
};
|
|
||||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
|
||||||
|
|
||||||
let outdim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let outdim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
|
||||||
.map(NDArrayValue::into)
|
|
||||||
.map(PointerValue::into)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
|
||||||
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
};
|
||||||
|
|
||||||
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call.
|
||||||
|
let x2 = call_float(generator, ctx, (x2_ty, x2))?;
|
||||||
|
let BasicValueEnum::FloatValue(x2) = x2 else {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
};
|
||||||
|
|
||||||
|
let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into())
|
||||||
|
.construct_unsized(generator, ctx, &x2, None); // x2.shape == []
|
||||||
|
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
|
||||||
|
|
||||||
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
||||||
|
.construct_uninitialized(generator, ctx, None);
|
||||||
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { out.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let x2_c = x2.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let out_c = out.make_contiguous_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_matrix_power(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
x2_c.as_base_value().into(),
|
||||||
|
out_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_det` linalg function
|
/// Invokes the `np_linalg_det` linalg function
|
||||||
|
@ -1764,7 +1764,7 @@ def run() -> int32:
|
|||||||
test_ndarray_svd()
|
test_ndarray_svd()
|
||||||
test_ndarray_linalg_inv()
|
test_ndarray_linalg_inv()
|
||||||
test_ndarray_pinv()
|
test_ndarray_pinv()
|
||||||
# test_ndarray_matrix_power()
|
test_ndarray_matrix_power()
|
||||||
test_ndarray_det()
|
test_ndarray_det()
|
||||||
test_ndarray_lu()
|
test_ndarray_lu()
|
||||||
test_ndarray_schur()
|
test_ndarray_schur()
|
||||||
|
Loading…
Reference in New Issue
Block a user