From 548739b86f9d8699abca5cc096ab9587b3dd2fa7 Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 25 Aug 2024 00:04:10 +0800 Subject: [PATCH] core/ndstrides: implement general matmul --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/matmul.hpp | 100 ++++++ nac3core/src/codegen/expr.rs | 6 +- nac3core/src/codegen/irrt/mod.rs | 27 ++ nac3core/src/codegen/numpy.rs | 306 +----------------- nac3core/src/codegen/object/ndarray/matmul.rs | 218 +++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 1 + nac3core/src/typecheck/magic_methods.rs | 83 +++-- 8 files changed, 396 insertions(+), 348 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/matmul.hpp create mode 100644 nac3core/src/codegen/object/ndarray/matmul.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 95856830..70ef2392 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -12,4 +12,5 @@ #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/broadcast.hpp" -#include "irrt/ndarray/transpose.hpp" \ No newline at end of file +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..8c0a6a60 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace { +namespace ndarray { +namespace matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace matmul +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 732993d7..7c1ac239 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1578,7 +1578,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( if op.base == Operator::MatMult { // Handle matrix multiplication. - todo!() + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + let result = NDArrayObject::matmul(generator, ctx, left, right, out) + .split_unsized(generator, ctx); + Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum()))) } else { // For other operations, they are all elementwise operations. diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index c19d9e15..2bdec6a0 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1221,3 +1221,30 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .arg(axes) .returning_void(); } + +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a_ndims: Instance<'ctx, Int>, + a_shape: Instance<'ctx, Ptr>>, + b_ndims: Instance<'ctx, Int>, + b_shape: Instance<'ctx, Ptr>>, + final_ndims: Instance<'ctx, Int>, + new_a_shape: Instance<'ctx, Ptr>>, + new_b_shape: Instance<'ctx, Ptr>>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + FnCall::builder(generator, ctx, &name) + .arg(a_ndims) + .arg(a_shape) + .arg(b_ndims) + .arg(b_shape) + .arg(final_ndims) + .arg(new_a_shape) + .arg(new_b_shape) + .arg(dst_shape) + .returning_void(); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 0cbc3476..d3578d20 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,10 +1,10 @@ use inkwell::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + AddressSpace, IntPredicate, }; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; use crate::{ codegen::{ @@ -13,7 +13,6 @@ use crate::{ ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, - expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, @@ -35,10 +34,7 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; /// Creates an uninitialized `NDArray` instance. @@ -1438,302 +1434,6 @@ where Ok(ndarray) } -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/object/ndarray/matmul.rs b/nac3core/src/codegen/object/ndarray/matmul.rs new file mode 100644 index 00000000..a2f3406a --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/matmul.rs @@ -0,0 +1,218 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; +use util::gen_for_model; + +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes, + model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator, + }, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +use super::{NDArrayObject, NDArrayOut}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + in_a: NDArrayObject<'ctx>, + in_b: NDArrayObject<'ctx>, +) -> NDArrayObject<'ctx> { + assert!(in_a.ndims >= 2); + assert!(in_b.ndims >= 2); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims, in_b.ndims); + let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. + let (lhs, rhs, dst) = { + let in_lhs_ndims = in_a.ndims_llvm(generator, ctx.ctx); + let in_lhs_shape = in_a.instance.get(generator, ctx, |f| f.shape); + let in_rhs_ndims = in_b.ndims_llvm(generator, ctx.ctx); + let in_rhs_shape = in_b.instance.get(generator, ctx, |f| f.shape); + let lhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let rhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + + // Matmul dimension compatibility is checked here. + call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + in_lhs_ndims, + in_lhs_shape, + in_rhs_ndims, + in_rhs_shape, + ndims, + lhs_shape, + rhs_shape, + dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, rhs_shape); + + let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int); + dst.copy_shape_from_array(generator, ctx, dst_shape); + dst.create_data(generator, ctx); + + (lhs, rhs, dst) + }; + + let len = lhs.instance.get(generator, ctx, |f| f.shape).get_index_const( + generator, + ctx, + i64::try_from(ndims_int - 1).unwrap(), + ); + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(generator, ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices(); + let i = indices.get_index_const(generator, ctx, at_row); + let j = indices.get_index_const(generator, ctx, at_col); + + let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0, false); + let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1, false); + + gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, k); + let a_ik = lhs.get_scalar_by_indices(generator, ctx, indices); + + indices.set_index_const(ctx, at_row, k); + indices.set_index_const(ctx, at_col, j); + let b_kj = rhs.get_scalar_by_indices(generator, ctx, indices); + + // Restore `indices`. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, j); + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs.dtype), a_ik.value), + Binop::normal(Operator::Mult), + (&Some(rhs.dtype), b_kj.value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Perform `np.matmul` according to the rules in + /// . + /// + /// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`] + /// to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. + pub fn matmul( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: Self, + b: Self, + out: NDArrayOut<'ctx>, + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input"); + + /* + If both arguments are 2-D they are multiplied like conventional matrices. + If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly. + If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. + If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. + */ + + let new_a = if a.ndims == 1 { + // Prepend 1 to its dimensions + a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + a + }; + + let new_b = if b.ndims == 1 { + // Append 1 to its dimensions + b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + b + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + let zero = Int(Int32).const_0(generator, ctx.ctx); + + if a.ndims == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if b.ndims == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.instance.get(generator, ctx, |f| f.shape); + out_ndarray.assert_can_be_written_by_out( + generator, + ctx, + result.ndims, + result_shape, + ); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 276ee3a7..48729645 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -27,6 +27,7 @@ pub mod broadcast; pub mod factory; pub mod indexing; pub mod map; +pub mod matmul; pub mod nditer; pub mod shape_util; pub mod view; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 322f7b9e..5bee112a 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -5,10 +5,11 @@ use strum::IntoEnumIterator; use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; +use super::typedef::into_var_map; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, typecheck::{ @@ -175,19 +176,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +188,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +531,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +770,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);