diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index fdbffb39..0773589a 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..bec36c48 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,197 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace +{ +namespace ndarray +{ +namespace matmul +{ + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims, + SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape) +{ + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} + +/** + * @brief Perform einsum notation, the output is the broadcasts performed by `np.einsum("...ij,...jk->...ik", a, b)`. + * + * This function is an auxillary function to compute a matrix multiplication. + * + * Also see https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy-matmul. + * + * This function expects `dst_ndarray` to contain the following content when called: + * - `dst_ndarray->data` is allocated. Can be uninitialized. + * - `dst_ndarray->itemsize` is set to `sizeof(T)`. + * - `dst_ndarray->ndims` is set to be the correct ndims of the result of the matmul.. + * - `dst_ndarray->shape` is set to be the correct shape of the result of the matmul. + * - `dst_ndarray->strides` is ignored. + */ +template +void matmul_at_least_2d(NDArray *a_ndarray, NDArray *b_ndarray, NDArray *dst_ndarray) +{ + // All inputs' ndims should be >= 2 and be the same. + debug_assert_eq(SizeT, a_ndarray->ndims, b_ndarray->ndims); + debug_assert_eq(SizeT, a_ndarray->ndims, dst_ndarray->ndims); + debug_assert(SizeT, a_ndarray->ndims >= 2); + + debug_assert_eq(SizeT, a_ndarray->itemsize, sizeof(T)); + debug_assert_eq(SizeT, b_ndarray->itemsize, sizeof(T)); + debug_assert_eq(SizeT, dst_ndarray->itemsize, sizeof(T)); + + if (IRRT_DEBUG_ASSERT_BOOL) + { + // Check that the shapes are the same. + for (SizeT i = 0; i < a_ndarray->ndims - 2; i++) + { + if (dst_ndarray->shape[0] != a_ndarray->shape[0]) + { + raise_debug_assert(SizeT, "Bad shape. At axis {0}, a has {1}, dst has {2}", i, a_ndarray->shape[i], + dst_ndarray->shape[i]); + } + if (dst_ndarray->shape[0] != b_ndarray->shape[0]) + { + raise_debug_assert(SizeT, "Bad shape. At axis {0}, b has {1}, dst has {2}", i, b_ndarray->shape[i], + dst_ndarray->shape[i]); + } + } + } + + // Number of dimensions dedicated to stacking + // e.g., [4, 6, 1, 2, 3] + // ^^^^^^^ count these + const SizeT u = a_ndarray->ndims - 2; // Alias + + SizeT *a_mat_shape = a_ndarray->shape + u; + SizeT *b_mat_shape = b_ndarray->shape + u; + SizeT *dst_mat_shape = dst_ndarray->shape + u; + + // Assert that dst_ndarray has the correct shape + debug_assert_eq(SizeT, dst_mat_shape[0], a_mat_shape[0]); + debug_assert_eq(SizeT, dst_mat_shape[1], b_mat_shape[1]); + + // Check that a and b are compatible for matmul + if (a_mat_shape[1] != b_mat_shape[0]) + { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_mat_shape[1], b_mat_shape[0], NO_PARAM); + } + + // Iterate through shape[:-2]. i.e, + // Given a = [5, 4, 3, m, p] and b = [5, 4, 3, p, n]. We iterate with shape [5, 4, 3]. + SizeT *indices = (SizeT *)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims); + SizeT *mat_indices = indices + u; + NDIter iter; + iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides, dst_ndarray->data, indices); + + for (; iter.has_next(); iter.next()) + { + for (SizeT i = 0; i < dst_mat_shape[0]; i++) + { + for (SizeT j = 0; j < dst_mat_shape[1]; j++) + { + // `indices` is being reused to index into different ndarrays. + mat_indices[0] = i; + mat_indices[1] = j; + T *d = ndarray::basic::get_ptr(dst_ndarray, indices); + *d = 0; + + for (SizeT k = 0; k < a_ndarray->shape[1]; k++) + { + mat_indices[0] = i; + mat_indices[1] = k; + T *a = ndarray::basic::get_ptr(a_ndarray, indices); + + mat_indices[0] = k; + mat_indices[1] = j; + T *b = ndarray::basic::get_ptr(b_ndarray, indices); + + *d += (*a) * (*b); + } + } + } + } +} +} // namespace matmul +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::matmul; + + void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape, + int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape, + int32_t *dst_shape) + { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); + } + + void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape, + int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape, + int64_t *dst_shape) + { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); + } + + void __nac3_ndarray_float64_matmul_at_least_2d(NDArray *a_ndarray, NDArray *b_ndarray, + NDArray *dst_ndarray) + { + matmul_at_least_2d(a_ndarray, b_ndarray, dst_ndarray); + } + + void __nac3_ndarray_float64_matmul_at_least_2d64(NDArray *a_ndarray, NDArray *b_ndarray, + NDArray *dst_ndarray) + { + matmul_at_least_2d(a_ndarray, b_ndarray, dst_ndarray); + } +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 31ecd298..0072fd62 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1572,7 +1572,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( if op.base == Operator::MatMult { // Handle matrix multiplication. - todo!() + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + let result = NDArrayObject::matmul(generator, ctx, left, right, out) + .split_unsized(generator, ctx); + Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum()))) } else { // For other operations, they are all elementwise operations. diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 9f928e48..d3bb19ac 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1219,3 +1219,49 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .arg(axes) .returning_void(); } + +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a_ndims: Instance<'ctx, Int>, + a_shape: Instance<'ctx, Ptr>>, + b_ndims: Instance<'ctx, Int>, + b_shape: Instance<'ctx, Ptr>>, + final_ndims: Instance<'ctx, Int>, + new_a_shape: Instance<'ctx, Ptr>>, + new_b_shape: Instance<'ctx, Ptr>>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + CallFunction::begin(generator, ctx, &name) + .arg(a_ndims) + .arg(a_shape) + .arg(b_ndims) + .arg(b_shape) + .arg(final_ndims) + .arg(new_a_shape) + .arg(new_b_shape) + .arg(dst_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_float64_matmul_at_least_2d<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a_ndarray: Instance<'ctx, Ptr>>, + b_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_float64_matmul_at_least_2d", + ); + CallFunction::begin(generator, ctx, &name) + .arg(a_ndarray) + .arg(b_ndarray) + .arg(dst_ndarray) + .returning_void(); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e094c1b4..9ccc71a4 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1437,302 +1437,6 @@ where Ok(ndarray) } -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/object/ndarray/matmul.rs b/nac3core/src/codegen/object/ndarray/matmul.rs new file mode 100644 index 00000000..17d79db5 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/matmul.rs @@ -0,0 +1,153 @@ +use std::cmp::max; + +use crate::codegen::{ + irrt::{ + call_nac3_ndarray_float64_matmul_at_least_2d, call_nac3_ndarray_matmul_calculate_shapes, + }, + model::*, + object::ndarray::indexing::RustNDIndex, + CodeGenContext, CodeGenerator, +}; + +use super::{NDArrayObject, NDArrayOut}; + +/// Perform `np.einsum("...ij,...jk->...ik", a, b)`. +fn matmul_helper<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: NDArrayObject<'ctx>, + b: NDArrayObject<'ctx>, +) -> NDArrayObject<'ctx> { + assert!(a.ndims >= 2); + assert!(b.ndims >= 2); + + assert!(ctx.unifier.unioned(ctx.primitives.float, a.dtype)); + assert!(ctx.unifier.unioned(ctx.primitives.float, b.dtype)); + + let final_ndims_int = max(a.ndims, b.ndims); + + let a_ndims = a.ndims_llvm(generator, ctx.ctx); + let a_shape = a.instance.get(generator, ctx, |f| f.shape); + let b_ndims = b.ndims_llvm(generator, ctx.ctx); + let b_shape = b.instance.get(generator, ctx, |f| f.shape); + let final_ndims = Int(SizeT).const_int(generator, ctx.ctx, final_ndims_int); + let new_a_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value); + let new_b_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value); + let dst_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value); + + call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + a_ndims, + a_shape, + b_ndims, + b_shape, + final_ndims, + new_a_shape, + new_b_shape, + dst_shape, + ); + + let dst = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, final_ndims_int); + dst.copy_shape_from_array(generator, ctx, dst_shape); + dst.create_data(generator, ctx); + + let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape); + let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape); + + call_nac3_ndarray_float64_matmul_at_least_2d( + generator, + ctx, + new_a.instance, + new_b.instance, + dst.instance, + ); + + dst +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Perform `np.matmul` according to the rules in + /// . + /// + /// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`] + /// to handle when the output could be a scalar. + pub fn matmul( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: Self, + b: Self, + out: NDArrayOut<'ctx>, + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input"); + + /* + If both arguments are 2-D they are multiplied like conventional matrices. + If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly. + If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. + If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. + */ + + let new_a = if a.ndims == 1 { + // Prepend 1 to its dimensions + a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + a + }; + + let new_b = if b.ndims == 1 { + // Append 1 to its dimensions + b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + b + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = matmul_helper(generator, ctx, new_a, new_b); + + let zero = Int(Int32).const_0(generator, ctx.ctx); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + + if a.ndims == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if b.ndims == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { dtype } => { + // We don't support auto-casting right now, nor anything other than float64. + // Force the output dtype to be float64. + assert!(ctx.unifier.unioned(ctx.primitives.float, dtype)); + result + } + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + // TODO: It is possible to check the shapes before computing the matmul to die + // quicker to save computes. + let result_shape = result.instance.get(generator, ctx, |f| f.shape); + out_ndarray.assert_can_be_written_by_out( + generator, + ctx, + result.ndims, + result_shape, + ); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 9ee4fd33..9fb3d220 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -3,6 +3,7 @@ pub mod broadcast; pub mod factory; pub mod indexing; pub mod map; +pub mod matmul; pub mod nditer; pub mod shape_util; pub mod view; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index b5a0608c..98c43e75 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,5 +1,5 @@ use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimDef; +use crate::toplevel::helper::{extract_ndims, PrimDef}; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, @@ -520,36 +520,41 @@ pub fn typeof_binop( } Operator::MatMult => { - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -748,7 +753,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);