forked from M-Labs/nac3
core/ndstrides: implement general ndarray matmul
This commit is contained in:
parent
ae351f7678
commit
4fef633090
|
@ -8,6 +8,7 @@
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
#include <irrt/ndarray/iter.hpp>
|
#include <irrt/ndarray/iter.hpp>
|
||||||
|
#include <irrt/ndarray/matmul.hpp>
|
||||||
#include <irrt/ndarray/reshape.hpp>
|
#include <irrt/ndarray/reshape.hpp>
|
||||||
#include <irrt/ndarray/transpose.hpp>
|
#include <irrt/ndarray/transpose.hpp>
|
||||||
#include <irrt/original.hpp>
|
#include <irrt/original.hpp>
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/debug.hpp>
|
||||||
|
#include <irrt/exception.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/ndarray/basic.hpp>
|
||||||
|
#include <irrt/ndarray/broadcast.hpp>
|
||||||
|
#include <irrt/ndarray/iter.hpp>
|
||||||
|
|
||||||
|
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
namespace ndarray
|
||||||
|
{
|
||||||
|
namespace matmul
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* Suppose `a_shape == [1, 97, 4, 2]`
|
||||||
|
* and `b_shape == [99, 98, 1, 2, 5]`,
|
||||||
|
*
|
||||||
|
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
||||||
|
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
||||||
|
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
||||||
|
* ^^^^^^^^^^ ^^^^
|
||||||
|
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
||||||
|
*
|
||||||
|
* @param a_ndims Length of `a_shape`.
|
||||||
|
* @param a_shape Shape of `a`.
|
||||||
|
* @param b_ndims Length of `b_shape`.
|
||||||
|
* @param b_shape Shape of `b`.
|
||||||
|
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
||||||
|
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims,
|
||||||
|
SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape)
|
||||||
|
{
|
||||||
|
debug_assert(SizeT, a_ndims >= 2);
|
||||||
|
debug_assert(SizeT, b_ndims >= 2);
|
||||||
|
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
||||||
|
|
||||||
|
// Check that a and b are compatible for matmul
|
||||||
|
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2])
|
||||||
|
{
|
||||||
|
// This is a custom error message. Different from NumPy.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
||||||
|
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SizeT num_entries = 2;
|
||||||
|
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
||||||
|
{.ndims = b_ndims - 2, .shape = b_shape}};
|
||||||
|
|
||||||
|
// TODO: Optimize this
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
||||||
|
|
||||||
|
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
||||||
|
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
||||||
|
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
}
|
||||||
|
} // namespace matmul
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
using namespace ndarray::matmul;
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape,
|
||||||
|
int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape,
|
||||||
|
int32_t *dst_shape)
|
||||||
|
{
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape,
|
||||||
|
int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape,
|
||||||
|
int64_t *dst_shape)
|
||||||
|
{
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1573,7 +1573,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
if op.base == Operator::MatMult {
|
if op.base == Operator::MatMult {
|
||||||
// Handle matrix multiplication.
|
// Handle matrix multiplication.
|
||||||
todo!()
|
let left = left.to_ndarray(generator, ctx);
|
||||||
|
let right = right.to_ndarray(generator, ctx);
|
||||||
|
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
||||||
|
.split_unsized(generator, ctx);
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
|
||||||
} else {
|
} else {
|
||||||
// For other operations, they are all elementwise operations.
|
// For other operations, they are all elementwise operations.
|
||||||
|
|
||||||
|
|
|
@ -1231,3 +1231,30 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
.arg(axes)
|
.arg(axes)
|
||||||
.returning_void();
|
.returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
b_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
final_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name =
|
||||||
|
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||||
|
CallFunction::begin(generator, ctx, &name)
|
||||||
|
.arg(a_ndims)
|
||||||
|
.arg(a_shape)
|
||||||
|
.arg(b_ndims)
|
||||||
|
.arg(b_shape)
|
||||||
|
.arg(final_ndims)
|
||||||
|
.arg(new_a_shape)
|
||||||
|
.arg(new_b_shape)
|
||||||
|
.arg(dst_shape)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
|
|
@ -1437,302 +1437,6 @@ where
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
|
||||||
/// written to a new `ndarray`.
|
|
||||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
|
||||||
lhs: NDArrayValue<'ctx>,
|
|
||||||
rhs: NDArrayValue<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
|
||||||
|
|
||||||
// lhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// rhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(res) = res {
|
|
||||||
let res_ndims = res.load_ndims(ctx);
|
|
||||||
let res_dim0 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let res_dim1 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let lhs_dim0 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let rhs_dim1 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// res.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
res_ndims,
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[0] == lhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
||||||
let lhs_dim1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_dim0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
// lhs.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
|
||||||
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
|
||||||
} else {
|
|
||||||
lhs
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndarray = res.unwrap_or_else(|| {
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&(lhs, rhs),
|
|
||||||
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
|
||||||
|generator, ctx, (lhs, rhs), idx| {
|
|
||||||
gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
||||||
llvm_intrinsics::call_expect(
|
|
||||||
ctx,
|
|
||||||
idx.size(ctx, generator).get_type().const_int(2, false),
|
|
||||||
idx.size(ctx, generator),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let common_dim = {
|
|
||||||
let lhs_idx1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_idx0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx0 = unsafe {
|
|
||||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
let idx1 = unsafe {
|
|
||||||
let idx1 =
|
|
||||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
||||||
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
|
||||||
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_i32.const_zero(),
|
|
||||||
(common_dim, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
|
||||||
|
|
||||||
let ab_idx = generator.gen_array_var_alloc(
|
|
||||||
ctx,
|
|
||||||
llvm_i32.into(),
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let a = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
|
||||||
|
|
||||||
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
let b = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
|
||||||
ab_idx.set_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
idx1.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let a_mul_b = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), a),
|
|
||||||
Binop::normal(Operator::Mult),
|
|
||||||
(&Some(elem_ty), b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
let result = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), result),
|
|
||||||
Binop::normal(Operator::Add),
|
|
||||||
(&Some(elem_ty), a_mul_b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
ctx.builder.build_store(result_addr, result).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
Ok(result)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
|
@ -0,0 +1,207 @@
|
||||||
|
use std::cmp::max;
|
||||||
|
|
||||||
|
use nac3parser::ast::Operator;
|
||||||
|
use util::gen_for_model;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes,
|
||||||
|
model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::{magic_methods::Binop, typedef::Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{NDArrayObject, NDArrayOut};
|
||||||
|
|
||||||
|
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
|
||||||
|
///
|
||||||
|
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||||
|
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dst_dtype: Type,
|
||||||
|
in_a: NDArrayObject<'ctx>,
|
||||||
|
in_b: NDArrayObject<'ctx>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
assert!(in_a.ndims >= 2);
|
||||||
|
assert!(in_b.ndims >= 2);
|
||||||
|
|
||||||
|
// Deduce ndims of the result of matmul.
|
||||||
|
let ndims_int = max(in_a.ndims, in_b.ndims);
|
||||||
|
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int);
|
||||||
|
|
||||||
|
let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0);
|
||||||
|
let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1);
|
||||||
|
|
||||||
|
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
||||||
|
// destination ndarray to store the result of matmul.
|
||||||
|
let (a, b, dst) = {
|
||||||
|
let in_a_ndims = in_a.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let in_a_shape = in_a.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let in_b_ndims = in_b.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let in_b_shape = in_b.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let a_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
let b_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
|
||||||
|
// Matmul dimension compatibility is checked here.
|
||||||
|
call_nac3_ndarray_matmul_calculate_shapes(
|
||||||
|
generator, ctx, in_a_ndims, in_a_shape, in_b_ndims, in_b_shape, ndims, a_shape,
|
||||||
|
b_shape, dst_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
let a = in_a.broadcast_to(generator, ctx, ndims_int, a_shape);
|
||||||
|
let b = in_b.broadcast_to(generator, ctx, ndims_int, b_shape);
|
||||||
|
|
||||||
|
let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int);
|
||||||
|
dst.copy_shape_from_array(generator, ctx, dst_shape);
|
||||||
|
dst.create_data(generator, ctx);
|
||||||
|
|
||||||
|
(a, b, dst)
|
||||||
|
};
|
||||||
|
|
||||||
|
let len =
|
||||||
|
a.instance.get(generator, ctx, |f| f.shape).get_index_const(generator, ctx, ndims_int - 1);
|
||||||
|
|
||||||
|
let at_row = ndims_int - 2;
|
||||||
|
let at_col = ndims_int - 1;
|
||||||
|
|
||||||
|
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
|
||||||
|
let dst_zero = dst_dtype_llvm.const_zero();
|
||||||
|
|
||||||
|
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
||||||
|
let pdst_ij = hdl.get_pointer(generator, ctx);
|
||||||
|
|
||||||
|
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
|
||||||
|
|
||||||
|
let indices = hdl.get_indices();
|
||||||
|
let i = indices.get_index_const(generator, ctx, at_row);
|
||||||
|
let j = indices.get_index_const(generator, ctx, at_col);
|
||||||
|
|
||||||
|
gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| {
|
||||||
|
// `indices` is modified to index into `a` and `b`, and restored.
|
||||||
|
indices.set_index_const(ctx, at_row, i);
|
||||||
|
indices.set_index_const(ctx, at_col, k);
|
||||||
|
let a_ik = a.get_scalar_by_indices(generator, ctx, indices);
|
||||||
|
|
||||||
|
indices.set_index_const(ctx, at_row, k);
|
||||||
|
indices.set_index_const(ctx, at_col, j);
|
||||||
|
let b_kj = b.get_scalar_by_indices(generator, ctx, indices);
|
||||||
|
|
||||||
|
// Restore `indices`.
|
||||||
|
indices.set_index_const(ctx, at_row, i);
|
||||||
|
indices.set_index_const(ctx, at_col, j);
|
||||||
|
|
||||||
|
// x = a_[...]ik * b_[...]kj
|
||||||
|
let x = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(a.dtype), a_ik.value),
|
||||||
|
Binop::normal(Operator::Mult),
|
||||||
|
(&Some(b.dtype), b_kj.value),
|
||||||
|
ctx.current_loc,
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||||
|
|
||||||
|
// dst_[...]ij += x
|
||||||
|
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||||
|
let dst_ij = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(dst_dtype), dst_ij),
|
||||||
|
Binop::normal(Operator::Add),
|
||||||
|
(&Some(dst_dtype), x),
|
||||||
|
ctx.current_loc,
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||||
|
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
dst
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Perform `np.matmul` according to the rules in
|
||||||
|
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
|
||||||
|
///
|
||||||
|
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
|
||||||
|
/// to handle when the output could be a scalar.
|
||||||
|
///
|
||||||
|
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||||
|
pub fn matmul<G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: Self,
|
||||||
|
b: Self,
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
// Sanity check, but type inference should prevent this.
|
||||||
|
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
|
||||||
|
|
||||||
|
/*
|
||||||
|
If both arguments are 2-D they are multiplied like conventional matrices.
|
||||||
|
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
|
||||||
|
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
|
||||||
|
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
|
||||||
|
*/
|
||||||
|
|
||||||
|
let new_a = if a.ndims == 1 {
|
||||||
|
// Prepend 1 to its dimensions
|
||||||
|
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||||
|
} else {
|
||||||
|
a
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_b = if b.ndims == 1 {
|
||||||
|
// Append 1 to its dimensions
|
||||||
|
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||||
|
} else {
|
||||||
|
b
|
||||||
|
};
|
||||||
|
|
||||||
|
// NOTE: `result` will always be a newly allocated ndarray.
|
||||||
|
// Current implementation cannot do in-place matrix muliplication.
|
||||||
|
let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b);
|
||||||
|
|
||||||
|
// Postprocessing on the result to remove prepended/appended axes.
|
||||||
|
let mut postindices = vec![];
|
||||||
|
let zero = Int(Int32).const_0(generator, ctx.ctx);
|
||||||
|
|
||||||
|
if a.ndims == 1 {
|
||||||
|
// Remove the prepended 1
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.ndims == 1 {
|
||||||
|
// Remove the appended 1
|
||||||
|
postindices.push(RustNDIndex::Ellipsis);
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !postindices.is_empty() {
|
||||||
|
result = result.index(generator, ctx, &postindices);
|
||||||
|
}
|
||||||
|
|
||||||
|
match out {
|
||||||
|
NDArrayOut::NewNDArray { .. } => result,
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
||||||
|
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
out_ndarray.assert_can_be_written_by_out(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
result.ndims,
|
||||||
|
result_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
out_ndarray.copy_data_from(generator, ctx, result);
|
||||||
|
out_ndarray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod map;
|
pub mod map;
|
||||||
|
pub mod matmul;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::{extract_ndims, PrimDef};
|
||||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
|
@ -520,36 +520,41 @@ pub fn typeof_binop(
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
|
||||||
assert_eq!(values.len(), 1);
|
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
||||||
|
|
||||||
|
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||||
|
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||||
|
{
|
||||||
|
return Err(format!(
|
||||||
|
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
|
||||||
|
unifier.stringify(lhs),
|
||||||
|
unifier.stringify(rhs)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result_ndims = match (lhs_ndims, rhs_ndims) {
|
||||||
|
(0, _) | (_, 0) => {
|
||||||
|
return Err(
|
||||||
|
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
(1, 1) => 0,
|
||||||
};
|
(1, _) => rhs_ndims - 1,
|
||||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
(_, 1) => lhs_ndims - 1,
|
||||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
(m, n) => max(m, n),
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
|
||||||
assert_eq!(values.len(), 1);
|
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match (lhs_ndims, rhs_ndims) {
|
if result_ndims == 0 {
|
||||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
// If the result is unsized, NumPy returns a scalar.
|
||||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
primitives.float
|
||||||
return Err(format!(
|
} else {
|
||||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
let result_ndims_ty =
|
||||||
u8::from(rhs == 0)
|
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
|
||||||
))
|
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
|
||||||
}
|
|
||||||
(lhs, rhs) => {
|
|
||||||
return Err(format!(
|
|
||||||
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -752,7 +757,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
|
Loading…
Reference in New Issue