forked from M-Labs/nac3

core/ndstrides: implement general ndarray matmul

lyken 2024-08-20 20:37:38 +08:00
parent ae351f7678
commit 4fef633090
8 changed files with 367 additions and 326 deletions

View File

@@ -8,6 +8,7 @@
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/ndarray/matmul.hpp>
#include <irrt/ndarray/reshape.hpp>
#include <irrt/ndarray/transpose.hpp>
#include <irrt/original.hpp>

View File

@@ -0,0 +1,92 @@
#pragma once
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_types.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/iter.hpp>
// NOTE: Everything would be much easier and more elegant if einsum were implemented.
namespace
{
namespace ndarray
{
namespace matmul
{
/**
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
*
* Example:
* Suppose `a_shape == [1, 97, 4, 2]`
* and `b_shape == [99, 98, 1, 2, 5]`,
*
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
* `new_b_shape == [99, 98, 97, 2, 5]`,
* and `dst_shape == [99, 98, 97, 4, 5]`.
* (`[99, 98, 97]` is produced by broadcasting; `4x2 @ 2x5 => 4x5` gives the trailing `[4, 5]`.)
*
* @param a_ndims Length of `a_shape`.
* @param a_shape Shape of `a`.
* @param b_ndims Length of `b_shape`.
* @param b_shape Shape of `b`.
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
*/
template <typename SizeT>
void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims,
SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape)
{
debug_assert(SizeT, a_ndims >= 2);
debug_assert(SizeT, b_ndims >= 2);
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
// Check that a and b are compatible for matmul
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2])
{
// This is a custom error message, different from NumPy's.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?)",
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
}
const SizeT num_entries = 2;
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
{.ndims = b_ndims - 2, .shape = b_shape}};
// TODO: Optimize this
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
}
} // namespace matmul
} // namespace ndarray
} // namespace
extern "C"
{
using namespace ndarray::matmul;
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape,
int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape,
int32_t *dst_shape)
{
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape,
int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape,
int64_t *dst_shape)
{
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
}
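For illustration, a minimal, self-contained Rust sketch of the shape arithmetic this routine performs. `broadcast_batch_dims` is a simplified stand-in for `ndarray::broadcast::broadcast_shapes` and only covers the batch (leading) dimensions; the last two axes are handled as the matrix dimensions, as in `calculate_shapes` above. All names are illustrative and not part of this commit.

// Sketch of the shape arithmetic behind __nac3_ndarray_matmul_calculate_shapes.
fn broadcast_batch_dims(a: &[u64], b: &[u64]) -> Option<Vec<u64>> {
    // Align shapes from the right; a missing dimension behaves like 1.
    let ndims = a.len().max(b.len());
    let mut out = vec![1u64; ndims];
    for i in 0..ndims {
        let da = a.len().checked_sub(ndims - i).map_or(1, |j| a[j]);
        let db = b.len().checked_sub(ndims - i).map_or(1, |j| b[j]);
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // incompatible batch dimensions
        };
    }
    Some(out)
}

fn matmul_shapes(a_shape: &[u64], b_shape: &[u64]) -> Option<(Vec<u64>, Vec<u64>, Vec<u64>)> {
    let (n, m) = (a_shape.len(), b_shape.len());
    assert!(n >= 2 && m >= 2);
    if a_shape[n - 1] != b_shape[m - 2] {
        return None; // LHS columns must match RHS rows
    }
    // Broadcast everything except the last two axes, then reattach them.
    let batch = broadcast_batch_dims(&a_shape[..n - 2], &b_shape[..m - 2])?;
    let mut new_a = batch.clone();
    new_a.extend([a_shape[n - 2], a_shape[n - 1]]);
    let mut new_b = batch.clone();
    new_b.extend([b_shape[m - 2], b_shape[m - 1]]);
    let mut dst = batch;
    dst.extend([a_shape[n - 2], b_shape[m - 1]]);
    Some((new_a, new_b, dst))
}

fn main() {
    // Reproduces the example in the doc comment: [1, 97, 4, 2] @ [99, 98, 1, 2, 5].
    let (new_a, new_b, dst) = matmul_shapes(&[1, 97, 4, 2], &[99, 98, 1, 2, 5]).unwrap();
    assert_eq!(new_a, vec![99, 98, 97, 4, 2]);
    assert_eq!(new_b, vec![99, 98, 97, 2, 5]);
    assert_eq!(dst, vec![99, 98, 97, 4, 5]);
}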

View File

@@ -1573,7 +1573,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if op.base == Operator::MatMult {
// Handle matrix multiplication.
todo!()
let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx);
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
.split_unsized(generator, ctx);
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
} else {
// For other operations, they are all elementwise operations.

View File

@@ -1231,3 +1231,30 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.arg(axes)
.returning_void();
}
#[allow(clippy::too_many_arguments)]
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a_ndims: Instance<'ctx, Int<SizeT>>,
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
b_ndims: Instance<'ctx, Int<SizeT>>,
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
final_ndims: Instance<'ctx, Int<SizeT>>,
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
CallFunction::begin(generator, ctx, &name)
.arg(a_ndims)
.arg(a_shape)
.arg(b_ndims)
.arg(b_shape)
.arg(final_ndims)
.arg(new_a_shape)
.arg(new_b_shape)
.arg(dst_shape)
.returning_void();
}

View File

@@ -1437,302 +1437,6 @@ where
Ok(ndarray)
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
if cfg!(debug_assertions) {
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
// lhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// rhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
if let Some(res) = res {
let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe {
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let res_dim1 = unsafe {
res.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let lhs_dim0 = unsafe {
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let rhs_dim1 = unsafe {
rhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
// res.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::EQ,
res_ndims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[0] == lhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let rhs_dim0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
// lhs.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else {
lhs
};
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&(lhs, rhs),
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|generator, ctx, (lhs, rhs), idx| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
.unwrap())
},
|generator, ctx| {
Ok(Some(unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}))
},
|generator, ctx| {
Ok(Some(unsafe {
rhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
},
)
.unwrap()
});
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
llvm_intrinsics::call_expect(
ctx,
idx.size(ctx, generator).get_type().const_int(2, false),
idx.size(ctx, generator),
None,
);
let common_dim = {
let lhs_idx1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let rhs_idx0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
};
let idx0 = unsafe {
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
};
let idx1 = unsafe {
let idx1 =
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
};
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
ctx.builder.build_store(result_addr, result_identity).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_i32.const_zero(),
(common_dim, false),
|generator, ctx, _, i| {
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc(
ctx,
llvm_i32.into(),
llvm_usize.const_int(2, false),
None,
)?;
let a = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let b = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
ab_idx.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
idx1.into(),
);
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let a_mul_b = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), a),
Binop::normal(Operator::Mult),
(&Some(elem_ty), b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
let result = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), result),
Binop::normal(Operator::Add),
(&Some(elem_ty), a_mul_b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
ctx.builder.build_store(result_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
Ok(result)
})?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,

View File

@@ -0,0 +1,207 @@
use std::cmp::max;
use nac3parser::ast::Operator;
use util::gen_for_model;
use crate::{
codegen::{
expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes,
model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator,
},
typecheck::{magic_methods::Binop, typedef::Type},
};
use super::{NDArrayObject, NDArrayOut};
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
///
/// `dst_dtype` defines the dtype of the returned ndarray.
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst_dtype: Type,
in_a: NDArrayObject<'ctx>,
in_b: NDArrayObject<'ctx>,
) -> NDArrayObject<'ctx> {
assert!(in_a.ndims >= 2);
assert!(in_b.ndims >= 2);
// Deduce ndims of the result of matmul.
let ndims_int = max(in_a.ndims, in_b.ndims);
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int);
let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0);
let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1);
// Broadcast `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
// destination ndarray to store the result of the matmul.
let (a, b, dst) = {
let in_a_ndims = in_a.ndims_llvm(generator, ctx.ctx);
let in_a_shape = in_a.instance.get(generator, ctx, |f| f.shape);
let in_b_ndims = in_b.ndims_llvm(generator, ctx.ctx);
let in_b_shape = in_b.instance.get(generator, ctx, |f| f.shape);
let a_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
let b_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
// Matmul dimension compatibility is checked here.
call_nac3_ndarray_matmul_calculate_shapes(
generator, ctx, in_a_ndims, in_a_shape, in_b_ndims, in_b_shape, ndims, a_shape,
b_shape, dst_shape,
);
let a = in_a.broadcast_to(generator, ctx, ndims_int, a_shape);
let b = in_b.broadcast_to(generator, ctx, ndims_int, b_shape);
let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int);
dst.copy_shape_from_array(generator, ctx, dst_shape);
dst.create_data(generator, ctx);
(a, b, dst)
};
let len =
a.instance.get(generator, ctx, |f| f.shape).get_index_const(generator, ctx, ndims_int - 1);
let at_row = ndims_int - 2;
let at_col = ndims_int - 1;
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
let dst_zero = dst_dtype_llvm.const_zero();
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
let pdst_ij = hdl.get_pointer(generator, ctx);
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
let indices = hdl.get_indices();
let i = indices.get_index_const(generator, ctx, at_row);
let j = indices.get_index_const(generator, ctx, at_col);
gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| {
// `indices` is modified to index into `a` and `b`, and restored.
indices.set_index_const(ctx, at_row, i);
indices.set_index_const(ctx, at_col, k);
let a_ik = a.get_scalar_by_indices(generator, ctx, indices);
indices.set_index_const(ctx, at_row, k);
indices.set_index_const(ctx, at_col, j);
let b_kj = b.get_scalar_by_indices(generator, ctx, indices);
// Restore `indices`.
indices.set_index_const(ctx, at_row, i);
indices.set_index_const(ctx, at_col, j);
// x = a_[...]ik * b_[...]kj
let x = gen_binop_expr_with_values(
generator,
ctx,
(&Some(a.dtype), a_ik.value),
Binop::normal(Operator::Mult),
(&Some(b.dtype), b_kj.value),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
// dst_[...]ij += x
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
let dst_ij = gen_binop_expr_with_values(
generator,
ctx,
(&Some(dst_dtype), dst_ij),
Binop::normal(Operator::Add),
(&Some(dst_dtype), x),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
Ok(())
})
})
.unwrap();
dst
}
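The loop body above is the schoolbook matmul, dst[..., i, j] = sum_k a[..., i, k] * b[..., k, j], expressed through NDIter indices and the generic binop helpers. For reference, a short illustrative Rust sketch of the same index pattern over row-major flat buffers for a single 2-D float pair (not code from this commit):

// Naive 2-D matmul over row-major flat buffers: dst[i, j] = sum_k a[i, k] * b[k, j].
// a is n x p, b is p x m, the result is n x m.
fn matmul_2d(a: &[f64], b: &[f64], n: usize, p: usize, m: usize) -> Vec<f64> {
    let mut dst = vec![0.0; n * m];
    for i in 0..n {
        for j in 0..m {
            let mut acc = 0.0; // corresponds to storing dst_zero into pdst_ij
            for k in 0..p {
                acc += a[i * p + k] * b[k * m + j];
            }
            dst[i * m + j] = acc;
        }
    }
    dst
}

fn main() {
    // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] == [[19, 22], [43, 50]]
    let c = matmul_2d(&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0], 2, 2, 2);
    assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
}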
impl<'ctx> NDArrayObject<'ctx> {
/// Perform `np.matmul` according to the rules in
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
///
/// This function always returns an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
/// to handle the case where the output could be a scalar.
///
/// `dst_dtype` defines the dtype of the returned ndarray.
pub fn matmul<G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
out: NDArrayOut<'ctx>,
) -> Self {
// Sanity check, but type inference should prevent this.
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
/*
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
*/
let new_a = if a.ndims == 1 {
// Prepend 1 to its dimensions
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
} else {
a
};
let new_b = if b.ndims == 1 {
// Append 1 to its dimensions
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
} else {
b
};
// NOTE: `result` will always be a newly allocated ndarray.
// The current implementation cannot do in-place matrix multiplication.
let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b);
// Postprocessing on the result to remove prepended/appended axes.
let mut postindices = vec![];
let zero = Int(Int32).const_0(generator, ctx.ctx);
if a.ndims == 1 {
// Remove the prepended 1
postindices.push(RustNDIndex::SingleElement(zero));
}
if b.ndims == 1 {
// Remove the appended 1
postindices.push(RustNDIndex::Ellipsis);
postindices.push(RustNDIndex::SingleElement(zero));
}
if !postindices.is_empty() {
result = result.index(generator, ctx, &postindices);
}
match out {
NDArrayOut::NewNDArray { .. } => result,
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
out_ndarray.assert_can_be_written_by_out(
generator,
ctx,
result.ndims,
result_shape,
);
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray
}
}
}
}
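As a quick check of the 1-D promotion rules quoted in the comment above (prepend an axis to a 1-D LHS, append one to a 1-D RHS, then strip it from the result with a zero index), a small illustrative Rust sketch of the shape bookkeeping; the names are not part of the codebase:

// Shapes after the 1-D promotion performed before the actual matmul.
fn promoted_shapes(a: &[u64], b: &[u64]) -> (Vec<u64>, Vec<u64>) {
    // a[None, ...] prepends a length-1 axis; b[..., None] appends one.
    let a2 = if a.len() == 1 { vec![1, a[0]] } else { a.to_vec() };
    let b2 = if b.len() == 1 { vec![b[0], 1] } else { b.to_vec() };
    (a2, b2)
}

fn main() {
    // (2,) @ (2, 5): the LHS is promoted to (1, 2), matmul yields (1, 5),
    // and indexing with [0] strips the prepended axis, giving (5,).
    assert_eq!(promoted_shapes(&[2], &[2, 5]), (vec![1, 2], vec![2, 5]));
    // (4, 2) @ (2,): the RHS is promoted to (2, 1), matmul yields (4, 1),
    // and indexing with [..., 0] strips the appended axis, giving (4,).
    assert_eq!(promoted_shapes(&[4, 2], &[2]), (vec![4, 2], vec![2, 1]));
}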

View File

@@ -3,6 +3,7 @@ pub mod broadcast;
pub mod factory;
pub mod indexing;
pub mod map;
pub mod matmul;
pub mod nditer;
pub mod shape_util;
pub mod view;

View File

@@ -1,5 +1,5 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::helper::{extract_ndims, PrimDef};
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
@@ -520,36 +520,41 @@ pub fn typeof_binop(
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
if !(unifier.unioned(lhs_dtype, primitives.float)
&& unifier.unioned(rhs_dtype, primitives.float))
{
return Err(format!(
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
unifier.stringify(lhs),
unifier.stringify(rhs)
));
}
let result_ndims = match (lhs_ndims, rhs_ndims) {
(0, _) | (_, 0) => {
return Err(
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
)
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
(1, 1) => 0,
(1, _) => rhs_ndims - 1,
(_, 1) => lhs_ndims - 1,
(m, n) => max(m, n),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
u8::from(rhs == 0)
))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
if result_ndims == 0 {
// If the result is unsized, NumPy returns a scalar.
primitives.float
} else {
let result_ndims_ty =
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
}
}
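For reference, the ndims deduction encoded by the match above can be restated as a small standalone function (an illustrative sketch mirroring the added arms, not code from the commit):

// Result ndims of ndarray.__matmul__ as deduced by typeof_binop.
fn matmul_result_ndims(lhs_ndims: u64, rhs_ndims: u64) -> Result<u64, String> {
    match (lhs_ndims, rhs_ndims) {
        (0, _) | (_, 0) => {
            Err("ndarray.__matmul__ does not allow unsized ndarray input".to_string())
        }
        (1, 1) => Ok(0),        // vector @ vector produces a scalar (float64)
        (1, n) => Ok(n - 1),    // the prepended axis is removed again
        (n, 1) => Ok(n - 1),    // the appended axis is removed again
        (m, n) => Ok(m.max(n)), // stacked matmul; batch dimensions broadcast
    }
}

fn main() {
    assert_eq!(matmul_result_ndims(2, 2), Ok(2));
    assert_eq!(matmul_result_ndims(1, 3), Ok(2));
    assert_eq!(matmul_result_ndims(1, 1), Ok(0)); // scalar result, not an ndarray
    assert!(matmul_result_ndims(0, 2).is_err());
}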
@@ -752,7 +757,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);