core/ndstrides: implement general ndarray matmul

This commit is contained in:
lyken 2024-08-20 20:37:38 +08:00
parent 6779c82af5
commit 5b4a6dd781
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
8 changed files with 437 additions and 326 deletions

View File

@ -8,6 +8,7 @@
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/ndarray/matmul.hpp>
#include <irrt/ndarray/reshape.hpp>
#include <irrt/ndarray/transpose.hpp>
#include <irrt/original.hpp>

View File

@ -0,0 +1,197 @@
#pragma once
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_types.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/iter.hpp>
// NOTE: Everything would be much easier and elegant if einsum is implemented.
namespace
{
namespace ndarray
{
namespace matmul
{
/**
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
*
* Example:
* Suppose `a_shape == [1, 97, 4, 2]`
* and `b_shape == [99, 98, 1, 2, 5]`,
*
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
* `new_b_shape == [99, 98, 97, 2, 5]`,
* and `dst_shape == [99, 98, 97, 4, 5]`.
* ^^^^^^^^^^ ^^^^
* (broadcasted) (4x2 @ 2x5 => 4x5)
*
* @param a_ndims Length of `a_shape`.
* @param a_shape Shape of `a`.
* @param b_ndims Length of `b_shape`.
* @param b_shape Shape of `b`.
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
*/
template <typename SizeT>
void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims,
SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape)
{
debug_assert(SizeT, a_ndims >= 2);
debug_assert(SizeT, b_ndims >= 2);
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
const SizeT num_entries = 2;
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
{.ndims = b_ndims - 2, .shape = b_shape}};
// TODO: Optimize this
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
}
/**
* @brief Perform einsum notation, the output is the broadcasts performed by `np.einsum("...ij,...jk->...ik", a, b)`.
*
* This function is an auxillary function to compute a matrix multiplication.
*
* Also see https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy-matmul.
*
* This function expects `dst_ndarray` to contain the following content when called:
* - `dst_ndarray->data` is allocated. Can be uninitialized.
* - `dst_ndarray->itemsize` is set to `sizeof(T)`.
* - `dst_ndarray->ndims` is set to be the correct ndims of the result of the matmul..
* - `dst_ndarray->shape` is set to be the correct shape of the result of the matmul.
* - `dst_ndarray->strides` is ignored.
*/
template <typename SizeT, typename T>
void matmul_at_least_2d(NDArray<SizeT> *a_ndarray, NDArray<SizeT> *b_ndarray, NDArray<SizeT> *dst_ndarray)
{
// All inputs' ndims should be >= 2 and be the same.
debug_assert_eq(SizeT, a_ndarray->ndims, b_ndarray->ndims);
debug_assert_eq(SizeT, a_ndarray->ndims, dst_ndarray->ndims);
debug_assert(SizeT, a_ndarray->ndims >= 2);
debug_assert_eq(SizeT, a_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, b_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, dst_ndarray->itemsize, sizeof(T));
if (IRRT_DEBUG_ASSERT_BOOL)
{
// Check that the shapes are the same.
for (SizeT i = 0; i < a_ndarray->ndims - 2; i++)
{
if (dst_ndarray->shape[0] != a_ndarray->shape[0])
{
raise_debug_assert(SizeT, "Bad shape. At axis {0}, a has {1}, dst has {2}", i, a_ndarray->shape[i],
dst_ndarray->shape[i]);
}
if (dst_ndarray->shape[0] != b_ndarray->shape[0])
{
raise_debug_assert(SizeT, "Bad shape. At axis {0}, b has {1}, dst has {2}", i, b_ndarray->shape[i],
dst_ndarray->shape[i]);
}
}
}
// Number of dimensions dedicated to stacking
// e.g., [4, 6, 1, 2, 3]
// ^^^^^^^ count these
const SizeT u = a_ndarray->ndims - 2; // Alias
SizeT *a_mat_shape = a_ndarray->shape + u;
SizeT *b_mat_shape = b_ndarray->shape + u;
SizeT *dst_mat_shape = dst_ndarray->shape + u;
// Assert that dst_ndarray has the correct shape
debug_assert_eq(SizeT, dst_mat_shape[0], a_mat_shape[0]);
debug_assert_eq(SizeT, dst_mat_shape[1], b_mat_shape[1]);
// Check that a and b are compatible for matmul
if (a_mat_shape[1] != b_mat_shape[0])
{
// This is a custom error message. Different from NumPy.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
a_mat_shape[1], b_mat_shape[0], NO_PARAM);
}
// Iterate through shape[:-2]. i.e,
// Given a = [5, 4, 3, m, p] and b = [5, 4, 3, p, n]. We iterate with shape [5, 4, 3].
SizeT *indices = (SizeT *)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
SizeT *mat_indices = indices + u;
NDIter<SizeT> iter;
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides, dst_ndarray->data, indices);
for (; iter.has_next(); iter.next())
{
for (SizeT i = 0; i < dst_mat_shape[0]; i++)
{
for (SizeT j = 0; j < dst_mat_shape[1]; j++)
{
// `indices` is being reused to index into different ndarrays.
mat_indices[0] = i;
mat_indices[1] = j;
T *d = ndarray::basic::get_ptr<SizeT, T>(dst_ndarray, indices);
*d = 0;
for (SizeT k = 0; k < a_ndarray->shape[1]; k++)
{
mat_indices[0] = i;
mat_indices[1] = k;
T *a = ndarray::basic::get_ptr<SizeT, T>(a_ndarray, indices);
mat_indices[0] = k;
mat_indices[1] = j;
T *b = ndarray::basic::get_ptr<SizeT, T>(b_ndarray, indices);
*d += (*a) * (*b);
}
}
}
}
}
} // namespace matmul
} // namespace ndarray
} // namespace
extern "C"
{
using namespace ndarray::matmul;
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape,
int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape,
int32_t *dst_shape)
{
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape,
int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape,
int64_t *dst_shape)
{
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_float64_matmul_at_least_2d(NDArray<int32_t> *a_ndarray, NDArray<int32_t> *b_ndarray,
NDArray<int32_t> *dst_ndarray)
{
matmul_at_least_2d<int32_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
void __nac3_ndarray_float64_matmul_at_least_2d64(NDArray<int64_t> *a_ndarray, NDArray<int64_t> *b_ndarray,
NDArray<int64_t> *dst_ndarray)
{
matmul_at_least_2d<int64_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
}

View File

@ -1572,7 +1572,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if op.base == Operator::MatMult {
// Handle matrix multiplication.
todo!()
let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx);
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
.split_unsized(generator, ctx);
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
} else {
// For other operations, they are all elementwise operations.

View File

@ -1227,3 +1227,49 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.arg(axes)
.returning_void();
}
#[allow(clippy::too_many_arguments)]
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a_ndims: Instance<'ctx, Int<SizeT>>,
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
b_ndims: Instance<'ctx, Int<SizeT>>,
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
final_ndims: Instance<'ctx, Int<SizeT>>,
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
) {
let name =
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
CallFunction::begin(generator, ctx, &name)
.arg(a_ndims)
.arg(a_shape)
.arg(b_ndims)
.arg(b_shape)
.arg(final_ndims)
.arg(new_a_shape)
.arg(new_b_shape)
.arg(dst_shape)
.returning_void();
}
pub fn call_nac3_ndarray_float64_matmul_at_least_2d<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
b_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_float64_matmul_at_least_2d",
);
CallFunction::begin(generator, ctx, &name)
.arg(a_ndarray)
.arg(b_ndarray)
.arg(dst_ndarray)
.returning_void();
}

View File

@ -1437,302 +1437,6 @@ where
Ok(ndarray)
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
if cfg!(debug_assertions) {
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
// lhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// rhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
if let Some(res) = res {
let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe {
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let res_dim1 = unsafe {
res.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let lhs_dim0 = unsafe {
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let rhs_dim1 = unsafe {
rhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
// res.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::EQ,
res_ndims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[0] == lhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let rhs_dim0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
// lhs.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else {
lhs
};
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&(lhs, rhs),
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|generator, ctx, (lhs, rhs), idx| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
.unwrap())
},
|generator, ctx| {
Ok(Some(unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}))
},
|generator, ctx| {
Ok(Some(unsafe {
rhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
},
)
.unwrap()
});
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
llvm_intrinsics::call_expect(
ctx,
idx.size(ctx, generator).get_type().const_int(2, false),
idx.size(ctx, generator),
None,
);
let common_dim = {
let lhs_idx1 = unsafe {
lhs.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let rhs_idx0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
};
let idx0 = unsafe {
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
};
let idx1 = unsafe {
let idx1 =
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
};
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
ctx.builder.build_store(result_addr, result_identity).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_i32.const_zero(),
(common_dim, false),
|generator, ctx, _, i| {
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc(
ctx,
llvm_i32.into(),
llvm_usize.const_int(2, false),
None,
)?;
let a = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let b = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
ab_idx.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
idx1.into(),
);
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let a_mul_b = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), a),
Binop::normal(Operator::Mult),
(&Some(elem_ty), b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
let result = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), result),
Binop::normal(Operator::Add),
(&Some(elem_ty), a_mul_b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
ctx.builder.build_store(result_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
Ok(result)
})?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,

View File

@ -0,0 +1,153 @@
use std::cmp::max;
use crate::codegen::{
irrt::{
call_nac3_ndarray_float64_matmul_at_least_2d, call_nac3_ndarray_matmul_calculate_shapes,
},
model::*,
object::ndarray::indexing::RustNDIndex,
CodeGenContext, CodeGenerator,
};
use super::{NDArrayObject, NDArrayOut};
/// Perform `np.einsum("...ij,...jk->...ik", a, b)`.
fn matmul_helper<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: NDArrayObject<'ctx>,
b: NDArrayObject<'ctx>,
) -> NDArrayObject<'ctx> {
assert!(a.ndims >= 2);
assert!(b.ndims >= 2);
assert!(ctx.unifier.unioned(ctx.primitives.float, a.dtype));
assert!(ctx.unifier.unioned(ctx.primitives.float, b.dtype));
let final_ndims_int = max(a.ndims, b.ndims);
let a_ndims = a.ndims_llvm(generator, ctx.ctx);
let a_shape = a.instance.get(generator, ctx, |f| f.shape);
let b_ndims = b.ndims_llvm(generator, ctx.ctx);
let b_shape = b.instance.get(generator, ctx, |f| f.shape);
let final_ndims = Int(SizeT).const_int(generator, ctx.ctx, final_ndims_int);
let new_a_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
let new_b_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
let dst_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
call_nac3_ndarray_matmul_calculate_shapes(
generator,
ctx,
a_ndims,
a_shape,
b_ndims,
b_shape,
final_ndims,
new_a_shape,
new_b_shape,
dst_shape,
);
let dst = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, final_ndims_int);
dst.copy_shape_from_array(generator, ctx, dst_shape);
dst.create_data(generator, ctx);
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
call_nac3_ndarray_float64_matmul_at_least_2d(
generator,
ctx,
new_a.instance,
new_b.instance,
dst.instance,
);
dst
}
impl<'ctx> NDArrayObject<'ctx> {
/// Perform `np.matmul` according to the rules in
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
///
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
/// to handle when the output could be a scalar.
pub fn matmul<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
out: NDArrayOut<'ctx>,
) -> Self {
// Sanity check, but type inference should prevent this.
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
/*
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
*/
let new_a = if a.ndims == 1 {
// Prepend 1 to its dimensions
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
} else {
a
};
let new_b = if b.ndims == 1 {
// Append 1 to its dimensions
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
} else {
b
};
// NOTE: `result` will always be a newly allocated ndarray.
// Current implementation cannot do in-place matrix muliplication.
let mut result = matmul_helper(generator, ctx, new_a, new_b);
let zero = Int(Int32).const_0(generator, ctx.ctx);
// Postprocessing on the result to remove prepended/appended axes.
let mut postindices = vec![];
if a.ndims == 1 {
// Remove the prepended 1
postindices.push(RustNDIndex::SingleElement(zero));
}
if b.ndims == 1 {
// Remove the appended 1
postindices.push(RustNDIndex::Ellipsis);
postindices.push(RustNDIndex::SingleElement(zero));
}
if !postindices.is_empty() {
result = result.index(generator, ctx, &postindices);
}
match out {
NDArrayOut::NewNDArray { dtype } => {
// We don't support auto-casting right now, nor anything other than float64.
// Force the output dtype to be float64.
assert!(ctx.unifier.unioned(ctx.primitives.float, dtype));
result
}
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// TODO: It is possible to check the shapes before computing the matmul to die
// quicker to save computes.
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
out_ndarray.assert_can_be_written_by_out(
generator,
ctx,
result.ndims,
result_shape,
);
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray
}
}
}
}

View File

@ -3,6 +3,7 @@ pub mod broadcast;
pub mod factory;
pub mod indexing;
pub mod map;
pub mod matmul;
pub mod nditer;
pub mod shape_util;
pub mod view;

View File

@ -1,5 +1,5 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::helper::{extract_ndims, PrimDef};
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
@ -520,36 +520,41 @@ pub fn typeof_binop(
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
if !(unifier.unioned(lhs_dtype, primitives.float)
&& unifier.unioned(rhs_dtype, primitives.float))
{
return Err(format!(
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
unifier.stringify(lhs),
unifier.stringify(rhs)
));
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
let result_ndims = match (lhs_ndims, rhs_ndims) {
(0, _) | (_, 0) => {
return Err(
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
)
}
_ => unreachable!(),
(1, 1) => 0,
(1, _) => rhs_ndims - 1,
(_, 1) => lhs_ndims - 1,
(m, n) => max(m, n),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
u8::from(rhs == 0)
))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
if result_ndims == 0 {
// If the result is unsized, NumPy returns a scalar.
primitives.float
} else {
let result_ndims_ty =
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
}
}
@ -748,7 +753,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);