[core] codegen/ndarray: Reimplement matmul

Based on 73c2203b: core/ndstrides: implement general matmul
This commit is contained in:
David Mak 2024-12-19 12:21:52 +08:00
parent 7da45a4fbd
commit a525592941
17 changed files with 581 additions and 995 deletions

View File

@ -1,7 +1,6 @@
#include "irrt/exception.hpp" #include "irrt/exception.hpp"
#include "irrt/list.hpp" #include "irrt/list.hpp"
#include "irrt/math.hpp" #include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/range.hpp" #include "irrt/range.hpp"
#include "irrt/slice.hpp" #include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/basic.hpp"
@ -12,3 +11,4 @@
#include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/reshape.hpp"
#include "irrt/ndarray/broadcast.hpp" #include "irrt/ndarray/broadcast.hpp"
#include "irrt/ndarray/transpose.hpp" #include "irrt/ndarray/transpose.hpp"
#include "irrt/ndarray/matmul.hpp"

View File

@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64);
#endif #endif
// NDArray indices are always `uint32_t`.
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`. // The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t; using SliceIndex = int32_t;

View File

@ -1,50 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
}

View File

@ -0,0 +1,100 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/broadcast.hpp"
#include "irrt/ndarray/iter.hpp"
// NOTE: Everything would be much easier and elegant if einsum is implemented.
namespace {
namespace ndarray {
namespace matmul {
/**
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
*
* Example:
* Suppose `a_shape == [1, 97, 4, 2]`
* and `b_shape == [99, 98, 1, 2, 5]`,
*
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
* `new_b_shape == [99, 98, 97, 2, 5]`,
* and `dst_shape == [99, 98, 97, 4, 5]`.
* ^^^^^^^^^^ ^^^^
* (broadcasted) (4x2 @ 2x5 => 4x5)
*
* @param a_ndims Length of `a_shape`.
* @param a_shape Shape of `a`.
* @param b_ndims Length of `b_shape`.
* @param b_shape Shape of `b`.
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
*/
template<typename SizeT>
void calculate_shapes(SizeT a_ndims,
SizeT* a_shape,
SizeT b_ndims,
SizeT* b_shape,
SizeT final_ndims,
SizeT* new_a_shape,
SizeT* new_b_shape,
SizeT* dst_shape) {
debug_assert(SizeT, a_ndims >= 2);
debug_assert(SizeT, b_ndims >= 2);
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
// Check that a and b are compatible for matmul
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
// This is a custom error message. Different from NumPy.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
}
const SizeT num_entries = 2;
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
{.ndims = b_ndims - 2, .shape = b_shape}};
// TODO: Optimize this
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
}
} // namespace matmul
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::matmul;
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
int32_t* a_shape,
int32_t b_ndims,
int32_t* b_shape,
int32_t final_ndims,
int32_t* new_a_shape,
int32_t* new_b_shape,
int32_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
int64_t* a_shape,
int64_t b_ndims,
int64_t* b_shape,
int64_t final_ndims,
int64_t* new_a_shape,
int64_t* new_b_shape,
int64_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
}
}

View File

@ -27,7 +27,7 @@ use super::{
call_int_umin, call_memcpy_generic, call_int_umin, call_memcpy_generic,
}, },
macros::codegen_unreachable, macros::codegen_unreachable,
need_sret, numpy, need_sret,
stmt::{ stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var, gen_var,
@ -1557,37 +1557,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val));
let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val));
if op.base == Operator::MatMult {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx);
// MatMult is the only binop which is not an elementwise op
let result = numpy::ndarray_matmul_2d(
generator,
ctx,
ndarray_dtype1,
match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(left),
},
left,
right,
)?;
Ok(Some(result.as_base_value().into()))
} else {
// For other operations, they are all elementwise operations.
// There are only three cases:
// - LHS is a scalar, RHS is an ndarray.
// - LHS is an ndarray, RHS is a scalar.
// - LHS is an ndarray, RHS is an ndarray.
//
// For all cases, the scalar operand is promoted to an ndarray,
// the two are then broadcasted, and starmapped through.
let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1);
let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2);
@ -1610,6 +1579,24 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} }
}; };
if op.base == Operator::MatMult {
let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx);
let result = left
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
.split_unsized(generator, ctx);
Ok(Some(result.to_basic_value_enum().into()))
} else {
// For other operations, they are all elementwise operations.
// There are only three cases:
// - LHS is a scalar, RHS is an ndarray.
// - LHS is an ndarray, RHS is a scalar.
// - LHS is an ndarray, RHS is an ndarray.
//
// For all cases, the scalar operand is promoted to an ndarray,
// the two are then broadcasted, and starmapped through.
let left = left.to_ndarray(generator, ctx); let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx);

View File

@ -0,0 +1,63 @@
use inkwell::types::BasicTypeEnum;
use inkwell::values::IntValue;
use crate::codegen::expr::infer_and_call_function;
use crate::codegen::irrt::get_usize_dependent_function_name;
use crate::codegen::values::TypedArrayLikeAccessor;
use crate::codegen::{CodeGenContext, CodeGenerator};
#[allow(clippy::too_many_arguments)]
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
final_ndims: IntValue<'ctx>,
new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
infer_and_call_function(
ctx,
&name,
None,
&[
a_shape.size(ctx, generator).into(),
a_shape.base_ptr(ctx, generator).into(),
b_shape.size(ctx, generator).into(),
b_shape.base_ptr(ctx, generator).into(),
final_ndims.into(),
new_a_shape.base_ptr(ctx, generator).into(),
new_b_shape.base_ptr(ctx, generator).into(),
dst_shape.base_ptr(ctx, generator).into(),
],
None,
None,
);
}

View File

@ -1,23 +1,9 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace,
};
use itertools::Either;
use super::get_usize_dependent_function_name;
use crate::codegen::{
values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
TypedArrayLikeAdapter,
},
CodeGenContext, CodeGenerator,
};
pub use array::*; pub use array::*;
pub use basic::*; pub use basic::*;
pub use broadcast::*; pub use broadcast::*;
pub use indexing::*; pub use indexing::*;
pub use iter::*; pub use iter::*;
pub use matmul::*;
pub use reshape::*; pub use reshape::*;
pub use transpose::*; pub use transpose::*;
@ -26,119 +12,6 @@ mod basic;
mod broadcast; mod broadcast;
mod indexing; mod indexing;
mod iter; mod iter;
mod matmul;
mod reshape; mod reshape;
mod transpose; mod transpose;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize));
assert!(end.is_none_or(|end| end.get_type() == llvm_usize));
assert_eq!(
BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let ndarray_calc_size_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size");
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The `llvm_usize` index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!(index.get_type(), llvm_usize);
let ndarray_calc_nd_indices_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices");
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|_, _, v| v.into_int_value(),
|_, _, v| v.into(),
)
}

View File

@ -1,736 +1,23 @@
use inkwell::{ use inkwell::{
types::BasicType, values::{BasicValue, BasicValueEnum, PointerValue},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, IntPredicate,
IntPredicate, OptimizationLevel,
}; };
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::StrRef;
use super::{ use super::{
expr::gen_binop_expr_with_values,
irrt::{
calculate_len_for_slice_range,
ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size},
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable, macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::gen_for_callback_incrementing,
types::ndarray::{factory::ndarray_zero_value, NDArrayType}, types::ndarray::NDArrayType,
values::{ values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor},
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::{ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::{ typecheck::typedef::{FunSignature, Type},
magic_methods::Binop,
typedef::{FunSignature, Type},
},
}; };
/// Creates an `NDArray` instance from a dynamic shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: &V,
shape_len_fn: LenFn,
shape_data_fn: DataFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&V,
IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
// Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim_gez = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
shape_dim,
shape_dim.get_type().const_zero(),
"",
)
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
// TODO: Disallow shape > u32_MAX
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let num_dims = shape_len_fn(generator, ctx, shape)?;
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
.construct_dyn_ndims(generator, ctx, num_dims, None);
// Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_pdim =
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
unsafe { ndarray.create_data(generator, ctx) };
Ok(ndarray)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
IntValue<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_num_elems = ndarray.size(generator, ctx);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(ndarray_num_elems, false),
|generator, ctx, _, i| {
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
let value = value_fn(generator, ctx, i)?;
ctx.builder.build_store(elem, value).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input.
fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| {
let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray);
value_fn(generator, ctx, &indices)
})
}
/// Copies a slice of an [`NDArrayValue`] to another.
///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
/// fields should be populated before calling this function.
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the destination array.
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the source array.
/// - `dim`: The index of the currently processing dimension.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
dim: u64,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<(), String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
// If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() {
let stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.shape(),
(Some(llvm_usize.const_int(dim, false)), None),
);
let stride =
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
return Ok(());
}
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
// arr[i + 1] in this dimension
let src_stride = call_ndarray_calc_size(
generator,
ctx,
&src_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let dst_stride = call_ndarray_calc_size(
generator,
ctx,
&dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None),
);
let (start, stop, step) = slices[0];
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
gen_for_range_callback(
generator,
ctx,
None,
false,
|_, _| Ok(start),
(|_, _| Ok(stop), true),
|_, _| Ok(step),
|generator, ctx, _, src_i| {
// Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let src_data_offset = ctx
.builder
.build_int_mul(
src_data_offset,
ctx.builder
.build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
let dst_i =
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
let dst_data_offset = ctx
.builder
.build_int_mul(
dst_data_offset,
ctx.builder
.build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
let (src_ptr, dst_ptr) = unsafe {
(
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
)
};
ndarray_sliced_copyto_impl(
generator,
ctx,
(dst_arr, dst_ptr),
(src_arr, src_ptr),
dim + 1,
&slices[1..],
)?;
let dst_i =
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let dst_i_add1 =
ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap();
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
Ok(())
},
)?;
Ok(())
}
/// Copies a [`NDArrayValue`] using slices.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray =
if slices.is_empty() {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|generator, ctx, shape, idx| unsafe {
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)?
} else {
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
.construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None);
// Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() {
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
let stop = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(
IntPredicate::SLT,
*step,
llvm_i32.const_zero(),
"is_neg",
)
.unwrap(),
ctx.builder
.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
.unwrap(),
ctx.builder
.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
.unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
let slice_len =
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
unsafe {
ndarray.shape().set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
slice_len,
);
}
}
// Populate the rest by directly copying the dim size from the source array
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false),
|generator, ctx, _, idx| {
unsafe {
let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape);
}
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
unsafe { ndarray.create_data(generator, ctx) };
ndarray
};
ndarray_sliced_copyto_impl(
generator,
ctx,
(ndarray, ndarray.data().base_ptr(ctx, generator)),
(this, this.data().base_ptr(ctx, generator)),
0,
slices,
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
if cfg!(debug_assertions) {
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
// lhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// rhs.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
if let Some(res) = res {
let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe {
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let res_dim1 = unsafe {
res.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let lhs_dim0 = unsafe {
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let rhs_dim1 = unsafe {
rhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
// res.ndims == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::EQ,
res_ndims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[0] == lhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
// res.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe {
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let rhs_dim0 = unsafe {
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
// lhs.dims[1] == rhs.dims[0]
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
"0:ValueError",
"",
[None, None, None],
ctx.current_loc,
);
}
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else {
lhs
};
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&(lhs, rhs),
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|generator, ctx, (lhs, rhs), idx| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
.unwrap())
},
|generator, ctx| {
Ok(Some(unsafe {
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}))
},
|generator, ctx| {
Ok(Some(unsafe {
rhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
},
)
.unwrap()
});
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
llvm_intrinsics::call_expect(
ctx,
idx.size(ctx, generator).get_type().const_int(2, false),
idx.size(ctx, generator),
None,
);
let common_dim = {
let lhs_idx1 = unsafe {
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
let rhs_idx0 = unsafe {
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
};
let idx0 = unsafe {
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
};
let idx1 = unsafe {
let idx1 =
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
};
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
ctx.builder.build_store(result_addr, result_identity).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(common_dim, false),
|generator, ctx, _, i| {
let ab_idx = generator.gen_array_var_alloc(
ctx,
llvm_usize.into(),
llvm_usize.const_int(2, false),
None,
)?;
let a = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let b = unsafe {
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
ab_idx.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
idx1.into(),
);
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
};
let a_mul_b = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), a),
Binop::normal(Operator::Mult),
(&Some(elem_ty), b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
let result = gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), result),
Binop::normal(Operator::Add),
(&Some(elem_ty), a_mul_b),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
ctx.builder.build_store(result_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let result = ctx.builder.build_load(result_addr, "").unwrap();
Ok(result)
})?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`. /// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>( pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>, context: &mut CodeGenContext<'ctx, '_>,

View File

@ -12,7 +12,7 @@ use crate::{
}; };
/// Get the zero value in `np.zeros()` of a `dtype`. /// Get the zero value in `np.zeros()` of a `dtype`.
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type, dtype: Type,

View File

@ -0,0 +1,331 @@
use crate::codegen::expr::gen_binop_expr_with_values;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::types::ndarray::NDArrayType;
use crate::codegen::values::ndarray::{NDArrayOut, NDArrayValue, RustNDIndex};
use crate::codegen::values::{
ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{irrt, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::arraylike_flatten_element_type;
use crate::typecheck::magic_methods::Binop;
use crate::typecheck::typedef::Type;
use nac3parser::ast::Operator;
use std::cmp::max;
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
///
/// `dst_dtype` defines the dtype of the returned ndarray.
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst_dtype: Type,
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
) -> NDArrayValue<'ctx> {
assert!(
in_a.ndims.is_some_and(|ndims| ndims >= 2),
"in_a (which is {:?}) must be compile-time known and >= 2",
in_a.ndims
);
assert!(
in_b.ndims.is_some_and(|ndims| ndims >= 2),
"in_b (which is {:?}) must be compile-time known and >= 2",
in_b.ndims
);
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
// Deduce ndims of the result of matmul.
let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap());
let ndims = llvm_usize.const_int(ndims_int, false);
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
// destination ndarray to store the result of matmul.
let (lhs, rhs, dst) = {
let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false);
let in_lhs_shape = TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(
in_a.shape().base_ptr(ctx, generator),
in_lhs_ndims,
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false);
let in_rhs_shape = TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(
in_b.shape().base_ptr(ctx, generator),
in_rhs_ndims,
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
let lhs_shape = TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
ndims,
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
let rhs_shape = TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
ndims,
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
let dst_shape = TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
ndims,
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
// Matmul dimension compatibility is checked here.
irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes(
generator,
ctx,
&in_lhs_shape,
&in_rhs_shape,
ndims,
&lhs_shape,
&rhs_shape,
&dst_shape,
);
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int))
.construct_uninitialized(generator, ctx, None);
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
unsafe {
dst.create_data(generator, ctx);
}
(lhs, rhs, dst)
};
let len = unsafe {
lhs.shape().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(ndims_int - 1, false),
None,
)
};
let at_row = i64::try_from(ndims_int - 2).unwrap();
let at_col = i64::try_from(ndims_int - 1).unwrap();
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
let dst_zero = dst_dtype_llvm.const_zero();
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
let pdst_ij = hdl.get_pointer(ctx);
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
let indices = hdl.get_indices::<G>();
let i = unsafe {
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None)
};
let j = unsafe {
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None)
};
let num_0 = llvm_usize.const_int(0, false);
let num_1 = llvm_usize.const_int(1, false);
gen_for_callback_incrementing(
generator,
ctx,
None,
num_0,
(len, false),
|generator, ctx, _, k| {
// `indices` is modified to index into `a` and `b`, and restored.
unsafe {
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_row as u64, true),
i,
);
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_col as u64, true),
k.into(),
);
}
let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) };
unsafe {
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_row as u64, true),
k.into(),
);
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_col as u64, true),
j,
);
}
let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) };
// Restore `indices`.
unsafe {
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_row as u64, true),
i,
);
indices.set_unchecked(
ctx,
generator,
&llvm_usize.const_int(at_col as u64, true),
j,
);
}
// x = a_[...]ik * b_[...]kj
let x = gen_binop_expr_with_values(
generator,
ctx,
(&Some(lhs_dtype), a_ik),
Binop::normal(Operator::Mult),
(&Some(rhs_dtype), b_kj),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
// dst_[...]ij += x
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
let dst_ij = gen_binop_expr_with_values(
generator,
ctx,
(&Some(dst_dtype), dst_ij),
Binop::normal(Operator::Add),
(&Some(dst_dtype), x),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
Ok(())
},
num_1,
)
})
.unwrap();
dst
}
impl<'ctx> NDArrayValue<'ctx> {
/// Perform `np.matmul` according to the rules in
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
///
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
/// to handle when the output could be a scalar.
///
/// `dst_dtype` defines the dtype of the returned ndarray.
#[must_use]
pub fn matmul<G: CodeGenerator>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
self_ty: Type,
(other_ty, other): (Type, Self),
(out_dtype, out): (Type, NDArrayOut<'ctx>),
) -> Self {
// Sanity check, but type inference should prevent this.
assert!(
self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0),
"np.matmul disallows scalar input"
);
/*
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
*/
let new_a = if self.ndims.unwrap() == 1 {
// Prepend 1 to its dimensions
self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
} else {
*self
};
let new_b = if other.ndims.unwrap() == 1 {
// Append 1 to its dimensions
other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
} else {
other
};
// NOTE: `result` will always be a newly allocated ndarray.
// Current implementation cannot do in-place matrix muliplication.
let mut result =
matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b));
// Postprocessing on the result to remove prepended/appended axes.
let mut postindices = vec![];
let zero = ctx.ctx.i32_type().const_zero();
if self.ndims.unwrap() == 1 {
// Remove the prepended 1
postindices.push(RustNDIndex::SingleElement(zero));
}
if other.ndims.unwrap() == 1 {
// Remove the appended 1
postindices.push(RustNDIndex::Ellipsis);
postindices.push(RustNDIndex::SingleElement(zero));
}
if !postindices.is_empty() {
result = result.index(generator, ctx, &postindices);
}
match out {
NDArrayOut::NewNDArray { .. } => result,
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// let result_shape = result.instance.get(generator, ctx, |f| f.shape);
let result_shape = result.shape();
out_ndarray.assert_can_be_written_by_out(
generator,
ctx,
result.ndims.unwrap(),
result_shape.as_slice_value(ctx, generator),
);
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray
}
}
}
}

View File

@ -32,6 +32,7 @@ mod broadcast;
mod contiguous; mod contiguous;
mod indexing; mod indexing;
mod map; mod map;
mod matmul;
mod nditer; mod nditer;
pub mod shape; pub mod shape;
mod view; mod view;

View File

@ -8,5 +8,5 @@ expression: res_vec
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
] ]

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n",
] ]

View File

@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use super::{ use super::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}; };
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{ toplevel::{
helper::PrimDef, helper::{extract_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
}, },
}; };
@ -175,19 +175,8 @@ pub fn impl_binop(
ops: &[Operator], ops: &[Operator],
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 { let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(other_ty[0], None) let function_vars = into_var_map([other_tvar]);
} else {
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(tvar.ty, Some(tvar.id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
@ -198,7 +187,7 @@ pub fn impl_binop(
ret: ret_ty, ret: ret_ty,
vars: function_vars.clone(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_tvar.ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false, is_vararg: false,
@ -541,36 +530,43 @@ pub fn typeof_binop(
} }
} }
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = extract_ndims(unifier, lhs_ndims);
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
u64::try_from(values[0].clone()).unwrap() let rhs_ndims = extract_ndims(unifier, rhs_ndims);
if !(unifier.unioned(lhs_dtype, primitives.float)
&& unifier.unioned(rhs_dtype, primitives.float))
{
return Err(format!(
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
unifier.stringify(lhs),
unifier.stringify(rhs)
));
} }
_ => unreachable!(),
}; // Deduce the ndims of the resulting ndarray.
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy.
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { let result_ndims = match (lhs_ndims, rhs_ndims) {
TypeEnum::TLiteral { values, .. } => { (0, _) | (_, 0) => {
assert_eq!(values.len(), 1); return Err(
u64::try_from(values[0].clone()).unwrap() "ndarray.__matmul__ does not allow unsized ndarray input".to_string()
)
} }
_ => unreachable!(), (1, 1) => 0,
(1, _) => rhs_ndims - 1,
(_, 1) => lhs_ndims - 1,
(m, n) => max(m, n),
}; };
match (lhs_ndims, rhs_ndims) { if result_ndims == 0 {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, // If the result is unsized, NumPy returns a scalar.
(lhs, rhs) if lhs == 0 || rhs == 0 => { primitives.float
return Err(format!( } else {
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", let result_ndims_ty =
u8::from(rhs == 0) unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
)) make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
} }
} }
@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);