diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 80dbf0aa..a41b9f55 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -17,7 +17,7 @@ use super::{ types::ndarray::NDArrayType, values::{ ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; @@ -2165,58 +2165,49 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); - let llvm_usize = generator.get_size_type(ctx.ctx); - if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; - - let n1 = llvm_ndarray_ty.map_value(n1, None); - // Changing second parameter to a `NDArray` for uniformity in function call - let n2_array = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - unsafe { - n2_array.data().set_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - n2.as_basic_value_enum(), - ); - }; - let n2_array = n2_array.as_base_value().as_basic_value_enum(); - - let outdim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let outdim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); - - Ok(out) - } else { + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }; + + let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None); + + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + // x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call. + let x2 = call_float(generator, ctx, (x2_ty, x2))?; + let BasicValueEnum::FloatValue(x2) = x2 else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }; + + let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into()) + .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] + let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] + + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + .construct_uninitialized(generator, ctx, None); + out.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { out.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let x2_c = x2.make_contiguous_ndarray(generator, ctx); + let out_c = out.make_contiguous_ndarray(generator, ctx); + + extern_fns::call_np_linalg_matrix_power( + ctx, + x1_c.as_base_value().into(), + x2_c.as_base_value().into(), + out_c.as_base_value().into(), + None, + ); + + Ok(out.as_base_value().into()) } /// Invokes the `np_linalg_det` linalg function diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index e82a6b7d..d42f3b93 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1764,7 +1764,7 @@ def run() -> int32: test_ndarray_svd() test_ndarray_linalg_inv() test_ndarray_pinv() - # test_ndarray_matrix_power() + test_ndarray_matrix_power() test_ndarray_det() test_ndarray_lu() test_ndarray_schur()