core/ndstrides: implement general ndarray matmul
This commit is contained in:
parent
e20a005437
commit
685e2cb0cd
|
@ -8,6 +8,7 @@
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
#include <irrt/ndarray/iter.hpp>
|
#include <irrt/ndarray/iter.hpp>
|
||||||
|
#include <irrt/ndarray/matmul.hpp>
|
||||||
#include <irrt/ndarray/reshape.hpp>
|
#include <irrt/ndarray/reshape.hpp>
|
||||||
#include <irrt/ndarray/transpose.hpp>
|
#include <irrt/ndarray/transpose.hpp>
|
||||||
#include <irrt/original.hpp>
|
#include <irrt/original.hpp>
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/debug.hpp>
|
||||||
|
#include <irrt/exception.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/ndarray/basic.hpp>
|
||||||
|
#include <irrt/ndarray/broadcast.hpp>
|
||||||
|
#include <irrt/ndarray/iter.hpp>
|
||||||
|
|
||||||
|
// NOTE: Everything would be much easier and more elegant if einsum were implemented.
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
namespace ndarray
|
||||||
|
{
|
||||||
|
namespace matmul
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* Suppose `a_shape == [1, 97, 4, 2]`
|
||||||
|
* and `b_shape == [99, 98, 1, 2, 5]`,
|
||||||
|
*
|
||||||
|
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
||||||
|
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
||||||
|
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
||||||
|
* ^^^^^^^^^^ ^^^^
|
||||||
|
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
||||||
|
*
|
||||||
|
* @param a_ndims Length of `a_shape`.
|
||||||
|
* @param a_shape Shape of `a`.
|
||||||
|
* @param b_ndims Length of `b_shape`.
|
||||||
|
* @param b_shape Shape of `b`.
|
||||||
|
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
||||||
|
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims,
|
||||||
|
SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape)
|
||||||
|
{
|
||||||
|
debug_assert(SizeT, a_ndims >= 2);
|
||||||
|
debug_assert(SizeT, b_ndims >= 2);
|
||||||
|
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
||||||
|
|
||||||
|
const SizeT num_entries = 2;
|
||||||
|
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
||||||
|
{.ndims = b_ndims - 2, .shape = b_shape}};
|
||||||
|
|
||||||
|
// TODO: Optimize this
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
||||||
|
|
||||||
|
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
||||||
|
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
||||||
|
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform einsum notation, the output is the broadcasts performed by `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||||
|
*
|
||||||
|
* This function is an auxillary function to compute a matrix multiplication.
|
||||||
|
*
|
||||||
|
* Also see https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy-matmul.
|
||||||
|
*
|
||||||
|
* This function expects `dst_ndarray` to contain the following content when called:
|
||||||
|
* - `dst_ndarray->data` is allocated. Can be uninitialized.
|
||||||
|
* - `dst_ndarray->itemsize` is set to `sizeof(T)`.
|
||||||
|
* - `dst_ndarray->ndims` is set to be the correct ndims of the result of the matmul..
|
||||||
|
* - `dst_ndarray->shape` is set to be the correct shape of the result of the matmul.
|
||||||
|
* - `dst_ndarray->strides` is ignored.
|
||||||
|
*/
|
||||||
|
template <typename SizeT, typename T>
|
||||||
|
void matmul_at_least_2d(NDArray<SizeT> *a_ndarray, NDArray<SizeT> *b_ndarray, NDArray<SizeT> *dst_ndarray)
|
||||||
|
{
|
||||||
|
// All inputs' ndims should be >= 2 and be the same.
|
||||||
|
debug_assert_eq(SizeT, a_ndarray->ndims, b_ndarray->ndims);
|
||||||
|
debug_assert_eq(SizeT, a_ndarray->ndims, dst_ndarray->ndims);
|
||||||
|
debug_assert(SizeT, a_ndarray->ndims >= 2);
|
||||||
|
|
||||||
|
debug_assert_eq(SizeT, a_ndarray->itemsize, sizeof(T));
|
||||||
|
debug_assert_eq(SizeT, b_ndarray->itemsize, sizeof(T));
|
||||||
|
debug_assert_eq(SizeT, dst_ndarray->itemsize, sizeof(T));
|
||||||
|
|
||||||
|
if (IRRT_DEBUG_ASSERT_BOOL)
|
||||||
|
{
|
||||||
|
// Check that the shapes are the same.
|
||||||
|
for (SizeT i = 0; i < a_ndarray->ndims - 2; i++)
|
||||||
|
{
|
||||||
|
if (dst_ndarray->shape[0] != a_ndarray->shape[0])
|
||||||
|
{
|
||||||
|
raise_debug_assert(SizeT, "Bad shape. At axis {0}, a has {1}, dst has {2}", i, a_ndarray->shape[i],
|
||||||
|
dst_ndarray->shape[i]);
|
||||||
|
}
|
||||||
|
if (dst_ndarray->shape[0] != b_ndarray->shape[0])
|
||||||
|
{
|
||||||
|
raise_debug_assert(SizeT, "Bad shape. At axis {0}, b has {1}, dst has {2}", i, b_ndarray->shape[i],
|
||||||
|
dst_ndarray->shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of dimensions dedicated to stacking
|
||||||
|
// e.g., [4, 6, 1, 2, 3]
|
||||||
|
// ^^^^^^^ count these
|
||||||
|
const SizeT u = a_ndarray->ndims - 2; // Alias
|
||||||
|
|
||||||
|
SizeT *a_mat_shape = a_ndarray->shape + u;
|
||||||
|
SizeT *b_mat_shape = b_ndarray->shape + u;
|
||||||
|
SizeT *dst_mat_shape = dst_ndarray->shape + u;
|
||||||
|
|
||||||
|
// Assert that dst_ndarray has the correct shape
|
||||||
|
debug_assert_eq(SizeT, dst_mat_shape[0], a_mat_shape[0]);
|
||||||
|
debug_assert_eq(SizeT, dst_mat_shape[1], b_mat_shape[1]);
|
||||||
|
|
||||||
|
// Check that a and b are compatible for matmul
|
||||||
|
if (a_mat_shape[1] != b_mat_shape[0])
|
||||||
|
{
|
||||||
|
// This is a custom error message. Different from NumPy.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
||||||
|
a_mat_shape[1], b_mat_shape[0], NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through shape[:-2]. i.e,
|
||||||
|
// Given a = [5, 4, 3, m, p] and b = [5, 4, 3, p, n]. We iterate with shape [5, 4, 3].
|
||||||
|
SizeT *indices = (SizeT *)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
|
||||||
|
SizeT *mat_indices = indices + u;
|
||||||
|
NDIter<SizeT> iter;
|
||||||
|
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides, dst_ndarray->data, indices);
|
||||||
|
|
||||||
|
for (; iter.has_next(); iter.next())
|
||||||
|
{
|
||||||
|
for (SizeT i = 0; i < dst_mat_shape[0]; i++)
|
||||||
|
{
|
||||||
|
for (SizeT j = 0; j < dst_mat_shape[1]; j++)
|
||||||
|
{
|
||||||
|
// `indices` is being reused to index into different ndarrays.
|
||||||
|
mat_indices[0] = i;
|
||||||
|
mat_indices[1] = j;
|
||||||
|
T *d = ndarray::basic::get_ptr<SizeT, T>(dst_ndarray, indices);
|
||||||
|
*d = 0;
|
||||||
|
|
||||||
|
for (SizeT k = 0; k < a_ndarray->shape[1]; k++)
|
||||||
|
{
|
||||||
|
mat_indices[0] = i;
|
||||||
|
mat_indices[1] = k;
|
||||||
|
T *a = ndarray::basic::get_ptr<SizeT, T>(a_ndarray, indices);
|
||||||
|
|
||||||
|
mat_indices[0] = k;
|
||||||
|
mat_indices[1] = j;
|
||||||
|
T *b = ndarray::basic::get_ptr<SizeT, T>(b_ndarray, indices);
|
||||||
|
|
||||||
|
*d += (*a) * (*b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace matmul
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
using namespace ndarray::matmul;
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape,
|
||||||
|
int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape,
|
||||||
|
int32_t *dst_shape)
|
||||||
|
{
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape,
|
||||||
|
int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape,
|
||||||
|
int64_t *dst_shape)
|
||||||
|
{
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_float64_matmul_at_least_2d(NDArray<int32_t> *a_ndarray, NDArray<int32_t> *b_ndarray,
|
||||||
|
NDArray<int32_t> *dst_ndarray)
|
||||||
|
{
|
||||||
|
matmul_at_least_2d<int32_t, double>(a_ndarray, b_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_float64_matmul_at_least_2d64(NDArray<int64_t> *a_ndarray, NDArray<int64_t> *b_ndarray,
|
||||||
|
NDArray<int64_t> *dst_ndarray)
|
||||||
|
{
|
||||||
|
matmul_at_least_2d<int64_t, double>(a_ndarray, b_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1572,7 +1572,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
if op.base == Operator::MatMult {
|
if op.base == Operator::MatMult {
|
||||||
// Handle matrix multiplication.
|
// Handle matrix multiplication.
|
||||||
todo!()
|
let left = left.to_ndarray(generator, ctx);
|
||||||
|
let right = right.to_ndarray(generator, ctx);
|
||||||
|
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
||||||
|
.split_unsized(generator, ctx);
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
|
||||||
} else {
|
} else {
|
||||||
// For other operations, they are all elementwise operations.
|
// For other operations, they are all elementwise operations.
|
||||||
|
|
||||||
|
|
|
@ -1219,3 +1219,49 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.arg(axes)
|
.arg(axes)
|
||||||
.returning_void();
|
.returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
b_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
final_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name =
|
||||||
|
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||||
|
CallFunction::begin(generator, ctx, &name)
|
||||||
|
.arg(a_ndims)
|
||||||
|
.arg(a_shape)
|
||||||
|
.arg(b_ndims)
|
||||||
|
.arg(b_shape)
|
||||||
|
.arg(final_ndims)
|
||||||
|
.arg(new_a_shape)
|
||||||
|
.arg(new_b_shape)
|
||||||
|
.arg(dst_shape)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_float64_matmul_at_least_2d<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||||
|
b_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||||
|
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_float64_matmul_at_least_2d",
|
||||||
|
);
|
||||||
|
CallFunction::begin(generator, ctx, &name)
|
||||||
|
.arg(a_ndarray)
|
||||||
|
.arg(b_ndarray)
|
||||||
|
.arg(dst_ndarray)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
|
@ -1437,302 +1437,6 @@ where
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
|
||||||
/// written to a new `ndarray`.
|
|
||||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
|
||||||
lhs: NDArrayValue<'ctx>,
|
|
||||||
rhs: NDArrayValue<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
|
||||||
|
|
||||||
// lhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// rhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(res) = res {
|
|
||||||
let res_ndims = res.load_ndims(ctx);
|
|
||||||
let res_dim0 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let res_dim1 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let lhs_dim0 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let rhs_dim1 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// res.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
res_ndims,
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[0] == lhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
||||||
let lhs_dim1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_dim0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
// lhs.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
|
||||||
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
|
||||||
} else {
|
|
||||||
lhs
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndarray = res.unwrap_or_else(|| {
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&(lhs, rhs),
|
|
||||||
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
|
||||||
|generator, ctx, (lhs, rhs), idx| {
|
|
||||||
gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
||||||
llvm_intrinsics::call_expect(
|
|
||||||
ctx,
|
|
||||||
idx.size(ctx, generator).get_type().const_int(2, false),
|
|
||||||
idx.size(ctx, generator),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let common_dim = {
|
|
||||||
let lhs_idx1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_idx0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx0 = unsafe {
|
|
||||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
let idx1 = unsafe {
|
|
||||||
let idx1 =
|
|
||||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
||||||
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
|
||||||
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_i32.const_zero(),
|
|
||||||
(common_dim, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
|
||||||
|
|
||||||
let ab_idx = generator.gen_array_var_alloc(
|
|
||||||
ctx,
|
|
||||||
llvm_i32.into(),
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let a = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
|
||||||
|
|
||||||
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
let b = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
|
||||||
ab_idx.set_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
idx1.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let a_mul_b = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), a),
|
|
||||||
Binop::normal(Operator::Mult),
|
|
||||||
(&Some(elem_ty), b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
let result = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), result),
|
|
||||||
Binop::normal(Operator::Add),
|
|
||||||
(&Some(elem_ty), a_mul_b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
ctx.builder.build_store(result_addr, result).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
Ok(result)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
use std::cmp::max;
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{
|
||||||
|
call_nac3_ndarray_float64_matmul_at_least_2d, call_nac3_ndarray_matmul_calculate_shapes,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
object::ndarray::indexing::RustNDIndex,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{NDArrayObject, NDArrayOut};
|
||||||
|
|
||||||
|
/// Perform `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||||
|
fn matmul_helper<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: NDArrayObject<'ctx>,
|
||||||
|
b: NDArrayObject<'ctx>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
assert!(a.ndims >= 2);
|
||||||
|
assert!(b.ndims >= 2);
|
||||||
|
|
||||||
|
assert!(ctx.unifier.unioned(ctx.primitives.float, a.dtype));
|
||||||
|
assert!(ctx.unifier.unioned(ctx.primitives.float, b.dtype));
|
||||||
|
|
||||||
|
let final_ndims_int = max(a.ndims, b.ndims);
|
||||||
|
|
||||||
|
let a_ndims = a.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let a_shape = a.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let b_ndims = b.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let b_shape = b.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let final_ndims = Int(SizeT).const_int(generator, ctx.ctx, final_ndims_int);
|
||||||
|
let new_a_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
|
||||||
|
let new_b_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
|
||||||
|
let dst_shape = Int(SizeT).array_alloca(generator, ctx, final_ndims.value);
|
||||||
|
|
||||||
|
call_nac3_ndarray_matmul_calculate_shapes(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
a_ndims,
|
||||||
|
a_shape,
|
||||||
|
b_ndims,
|
||||||
|
b_shape,
|
||||||
|
final_ndims,
|
||||||
|
new_a_shape,
|
||||||
|
new_b_shape,
|
||||||
|
dst_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
let dst = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, final_ndims_int);
|
||||||
|
dst.copy_shape_from_array(generator, ctx, dst_shape);
|
||||||
|
dst.create_data(generator, ctx);
|
||||||
|
|
||||||
|
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
|
||||||
|
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
|
||||||
|
|
||||||
|
call_nac3_ndarray_float64_matmul_at_least_2d(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
new_a.instance,
|
||||||
|
new_b.instance,
|
||||||
|
dst.instance,
|
||||||
|
);
|
||||||
|
|
||||||
|
dst
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Perform `np.matmul` according to the rules in
|
||||||
|
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
|
||||||
|
///
|
||||||
|
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
|
||||||
|
/// to handle when the output could be a scalar.
|
||||||
|
pub fn matmul<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: Self,
|
||||||
|
b: Self,
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
// Sanity check, but type inference should prevent this.
|
||||||
|
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
|
||||||
|
|
||||||
|
/*
|
||||||
|
If both arguments are 2-D they are multiplied like conventional matrices.
|
||||||
|
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
|
||||||
|
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
|
||||||
|
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
|
||||||
|
*/
|
||||||
|
|
||||||
|
let new_a = if a.ndims == 1 {
|
||||||
|
// Prepend 1 to its dimensions
|
||||||
|
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||||
|
} else {
|
||||||
|
a
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_b = if b.ndims == 1 {
|
||||||
|
// Append 1 to its dimensions
|
||||||
|
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||||
|
} else {
|
||||||
|
b
|
||||||
|
};
|
||||||
|
|
||||||
|
// NOTE: `result` will always be a newly allocated ndarray.
|
||||||
|
// Current implementation cannot do in-place matrix muliplication.
|
||||||
|
let mut result = matmul_helper(generator, ctx, new_a, new_b);
|
||||||
|
|
||||||
|
let zero = Int(Int32).const_0(generator, ctx.ctx);
|
||||||
|
|
||||||
|
// Postprocessing on the result to remove prepended/appended axes.
|
||||||
|
let mut postindices = vec![];
|
||||||
|
|
||||||
|
if a.ndims == 1 {
|
||||||
|
// Remove the prepended 1
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.ndims == 1 {
|
||||||
|
// Remove the appended 1
|
||||||
|
postindices.push(RustNDIndex::Ellipsis);
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !postindices.is_empty() {
|
||||||
|
result = result.index(generator, ctx, &postindices);
|
||||||
|
}
|
||||||
|
|
||||||
|
match out {
|
||||||
|
NDArrayOut::NewNDArray { dtype } => {
|
||||||
|
// We don't support auto-casting right now, nor anything other than float64.
|
||||||
|
// Force the output dtype to be float64.
|
||||||
|
assert!(ctx.unifier.unioned(ctx.primitives.float, dtype));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
||||||
|
// TODO: It is possible to check the shapes before computing the matmul to die
|
||||||
|
// quicker to save computes.
|
||||||
|
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
out_ndarray.assert_can_be_written_by_out(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
result.ndims,
|
||||||
|
result_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
out_ndarray.copy_data_from(generator, ctx, result);
|
||||||
|
out_ndarray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod map;
|
pub mod map;
|
||||||
|
pub mod matmul;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::{extract_ndims, PrimDef};
|
||||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
|
@ -520,36 +520,41 @@ pub fn typeof_binop(
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
|
||||||
assert_eq!(values.len(), 1);
|
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
||||||
|
|
||||||
|
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||||
|
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||||
|
{
|
||||||
|
return Err(format!(
|
||||||
|
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
|
||||||
|
unifier.stringify(lhs),
|
||||||
|
unifier.stringify(rhs)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result_ndims = match (lhs_ndims, rhs_ndims) {
|
||||||
|
(0, _) | (_, 0) => {
|
||||||
|
return Err(
|
||||||
|
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
(1, 1) => 0,
|
||||||
};
|
(1, _) => rhs_ndims - 1,
|
||||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
(_, 1) => lhs_ndims - 1,
|
||||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
(m, n) => max(m, n),
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
|
||||||
assert_eq!(values.len(), 1);
|
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match (lhs_ndims, rhs_ndims) {
|
if result_ndims == 0 {
|
||||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
// If the result is unsized, NumPy returns a scalar.
|
||||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
primitives.float
|
||||||
return Err(format!(
|
} else {
|
||||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
let result_ndims_ty =
|
||||||
u8::from(rhs == 0)
|
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
|
||||||
))
|
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
|
||||||
}
|
|
||||||
(lhs, rhs) => {
|
|
||||||
return Err(format!(
|
|
||||||
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -748,7 +753,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
|
Loading…
Reference in New Issue