From 66b8a5e01d73d84f1f1312f6915e3977ec7cc24c Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:21:52 +0800 Subject: [PATCH] [core] codegen/ndarray: Reimplement matmul Based on 73c2203b: core/ndstrides: implement general matmul --- nac3core/irrt/irrt.cpp | 4 +- nac3core/irrt/irrt/int_types.hpp | 2 - nac3core/irrt/irrt/ndarray.hpp | 50 -- nac3core/irrt/irrt/ndarray/matmul.hpp | 98 +++ nac3core/src/codegen/expr.rs | 67 +- nac3core/src/codegen/irrt/ndarray/matmul.rs | 66 ++ nac3core/src/codegen/irrt/ndarray/mod.rs | 131 +--- nac3core/src/codegen/numpy.rs | 727 +----------------- nac3core/src/codegen/types/ndarray/factory.rs | 2 +- nac3core/src/codegen/values/ndarray/matmul.rs | 334 ++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 1 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 84 +- 17 files changed, 585 insertions(+), 995 deletions(-) delete mode 100644 nac3core/irrt/irrt/ndarray.hpp create mode 100644 nac3core/irrt/irrt/ndarray/matmul.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/matmul.rs create mode 100644 nac3core/src/codegen/values/ndarray/matmul.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 39ddba67..87dcb428 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,7 +1,6 @@ #include "irrt/exception.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" -#include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" #include "irrt/string.hpp" @@ -12,4 +11,5 @@ #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/broadcast.hpp" -#include "irrt/ndarray/transpose.hpp" \ No newline at end of file +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No 
newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index ed8a48b8..17ccf604 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64); #endif -// NDArray indices are always `uint32_t`. -using NDIndexInt = uint32_t; // The type of an index or a value describing the length of a range/slice is always `int32_t`. using SliceIndex = int32_t; diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp deleted file mode 100644 index 534f18d6..00000000 --- a/nac3core/irrt/irrt/ndarray.hpp +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "irrt/int_types.hpp" - -// TODO: To be deleted since NDArray with strides is done. - -namespace { -template -SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} -} // namespace - -extern "C" { -uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t -__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, 
NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} -} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..b0fd4d86 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and more elegant if einsum were implemented. + +namespace { +namespace ndarray::matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. 
+ */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?)", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace ndarray::matmul +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + 
int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8a002bb3..4b83e63e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -27,7 +27,7 @@ use super::{ call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -1534,26 +1534,35 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if op.base == Operator::MatMult { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // Augmented assignment - `left` has to be an ndarray. If it were a scalar then NAC3 + // simply doesn't support it. 
+ if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + if op.base == Operator::MatMult { let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); - - // MatMult is the only binop which is not an elementwise op - let result = numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left), - }, - left, - right, - )?; - - Ok(Some(result.as_base_value().into())) + let result = left + .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) + .split_unsized(generator, ctx); + Ok(Some(result.to_basic_value_enum().into())) } else { // For other operations, they are all elementwise operations. @@ -1565,28 +1574,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // For all cases, the scalar operand is promoted to an ndarray, // the two are then broadcasted, and starmapped through. - let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); - let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); - - // Inhomogeneous binary operations are not supported. - assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); - - let common_dtype = ty1_dtype; - let llvm_common_dtype = left.get_dtype(); - - let out = match op.variant { - BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - BinopVariant::AugAssign => { - // If this is an augmented assignment. - // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. 
- if let ScalarOrNDArray::NDArray(out_ndarray) = left { - NDArrayOut::WriteToNDArray { ndarray: out_ndarray } - } else { - panic!("left must be an ndarray") - } - } - }; - let left = left.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx); diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs new file mode 100644 index 00000000..551cb7c7 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -0,0 +1,66 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, + values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`. +/// +/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of +/// `a @ b`. +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + final_ndims: IntValue<'ctx>, + new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + assert_eq!( + BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + assert_eq!( + 
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + + infer_and_call_function( + ctx, + &name, + None, + &[ + a_shape.size(ctx, generator).into(), + a_shape.base_ptr(ctx, generator).into(), + b_shape.size(ctx, generator).into(), + b_shape.base_ptr(ctx, generator).into(), + final_ndims.into(), + new_a_shape.base_ptr(ctx, generator).into(), + new_b_shape.base_ptr(ctx, generator).into(), + dst_shape.base_ptr(ctx, generator).into(), + ], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 151795c5..b1530685 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,23 +1,9 @@ -use inkwell::{ - types::BasicTypeEnum, - values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, -}; -use itertools::Either; - -use super::get_usize_dependent_function_name; -use crate::codegen::{ - values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAdapter, - }, - CodeGenContext, CodeGenerator, -}; pub use array::*; pub use basic::*; pub use broadcast::*; pub use indexing::*; pub use iter::*; +pub use matmul::*; pub use reshape::*; pub use transpose::*; @@ -26,119 +12,6 @@ mod basic; mod broadcast; mod indexing; mod iter; +mod matmul; mod reshape; mod transpose; - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns a -/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. 
-pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize)); - assert!(end.is_none_or(|end| end.get_type() == llvm_usize)); - assert_eq!( - BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - - let ndarray_calc_size_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size"); - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The `llvm_usize` index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. 
-pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - assert_eq!(index.get_type(), llvm_usize); - - let ndarray_calc_nd_indices_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices"); - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.shape(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9fe5a972..d46a6119 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,736 +1,23 @@ use inkwell::{ - types::BasicType, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - IntPredicate, OptimizationLevel, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; -use nac3parser::ast::{Operator, StrRef}; +use 
nac3parser::ast::StrRef; use super::{ - expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, - ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, - }, - llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::ndarray::{factory::ndarray_zero_value, NDArrayType}, - values::{ - ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, - TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, - UntypedArrayLikeMutator, - }, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; -/// Creates an `NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. 
-fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow shape > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, num_dims, None); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, 
llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - unsafe { ndarray.create_data(generator, ctx) }; - - Ok(ndarray) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = ndarray.size(generator, ctx); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. -fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -/// Copies a slice of an [`NDArrayValue`] to another. 
-/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); - - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. 
the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let src_data_offset = ctx - .builder - .build_int_mul( - src_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - let dst_data_offset = ctx - .builder - .build_int_mul( - dst_data_offset, - ctx.builder - .build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "") - .unwrap(), - "", - ) - .unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (dst_arr, dst_ptr), - 
(src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - - let ndarray = - if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? 
- } else { - let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None) - .construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.shape().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { ndarray.create_data(generator, ctx) }; - - ndarray - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. 
-/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - 
) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - }; - let rhs_dim0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? 
- } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.shape().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = 
ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_usize.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. 
pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs index 300167f7..2d0dca76 100644 --- a/nac3core/src/codegen/types/ndarray/factory.rs +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -12,7 +12,7 @@ use crate::{ }; /// Get the zero value in `np.zeros()` of a `dtype`. -pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, dtype: Type, diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs new file mode 100644 index 00000000..88a94394 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -0,0 +1,334 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; + +use super::{NDArrayOut, NDArrayValue, RustNDIndex}; +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, + irrt, + stmt::gen_for_callback_incrementing, + types::ndarray::NDArrayType, + values::{ + ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::arraylike_flatten_element_type, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. 
+fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + (in_a_ty, in_a): (Type, NDArrayValue<'ctx>), + (in_b_ty, in_b): (Type, NDArrayValue<'ctx>), +) -> NDArrayValue<'ctx> { + assert!( + in_a.ndims.is_some_and(|ndims| ndims >= 2), + "in_a (which is {:?}) must be compile-time known and >= 2", + in_a.ndims + ); + assert!( + in_b.ndims.is_some_and(|ndims| ndims >= 2), + "in_b (which is {:?}) must be compile-time known and >= 2", + in_b.ndims + ); + + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap()); + let ndims = llvm_usize.const_int(ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. 
+ let (lhs, rhs, dst) = { + let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false); + let in_lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_a.shape().base_ptr(ctx, generator), + in_lhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false); + let in_rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + in_b.shape().base_ptr(ctx, generator), + in_rhs_ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let lhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let rhs_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let dst_shape = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val( + ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(), + ndims, + None, + ), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Matmul dimension compatibility is checked here. 
+ irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + &in_lhs_shape, + &in_rhs_shape, + ndims, + &lhs_shape, + &rhs_shape, + &dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); + + let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, None); + dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); + unsafe { + dst.create_data(generator, ctx); + } + + (lhs, rhs, dst) + }; + + let len = unsafe { + lhs.shape().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(ndims_int - 1, false), + None, + ) + }; + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices::(); + let i = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None) + }; + let j = unsafe { + indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None) + }; + + let num_0 = llvm_usize.const_int(0, false); + let num_1 = llvm_usize.const_int(1, false); + + gen_for_callback_incrementing( + generator, + ctx, + None, + num_0, + (len, false), + |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. 
+ unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + k.into(), + ); + } + let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) }; + + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + k.into(), + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) }; + + // Restore `indices`. + unsafe { + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_row as u64, true), + i, + ); + indices.set_unchecked( + ctx, + generator, + &llvm_usize.const_int(at_col as u64, true), + j, + ); + } + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs_dtype), a_ik), + Binop::normal(Operator::Mult), + (&Some(rhs_dtype), b_kj), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }, + num_1, + ) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). + /// + /// This function always return an [`NDArrayValue`]. You may want to use + /// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. 
+ #[must_use] + pub fn matmul( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + self_ty: Type, + (other_ty, other): (Type, Self), + (out_dtype, out): (Type, NDArrayOut<'ctx>), + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!( + self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0), + "np.matmul disallows scalar input" + ); + + // If both arguments are 2-D they are multiplied like conventional matrices. + // + // If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the + // last two indices and broadcast accordingly. + // + // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its + // dimensions. After matrix multiplication the prepended 1 is removed. + // + // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its + // dimensions. After matrix multiplication the appended 1 is removed. + + let new_a = if self.ndims.unwrap() == 1 { + // Prepend 1 to its dimensions + self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + *self + }; + + let new_b = if other.ndims.unwrap() == 1 { + // Append 1 to its dimensions + other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + other + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = + matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b)); + + // Postprocessing on the result to remove prepended/appended axes. 
+ let mut postindices = vec![]; + let zero = ctx.ctx.i32_type().const_zero(); + + if self.ndims.unwrap() == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if other.ndims.unwrap() == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.shape(); + out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 89f88e74..707c79a2 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -32,6 +32,7 @@ mod broadcast; mod contiguous; mod indexing; mod map; +mod matmul; mod nditer; pub mod shape; mod view; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 41b39bb8..4332b474 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: 
\"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 90408d91..60e0c194 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git 
a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index f0418889..46601817 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 72e54e02..da58d121 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", 
\"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index a8a534cd..8f384fa1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: 
[TypeVarId(264)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f03..40bbdeab 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; use super::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, }; @@ -175,19 +175,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } 
else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +187,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +530,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. 
} => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);