forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement matmul
Based on 73c2203b
: core/ndstrides: implement general matmul
This commit is contained in:
parent
ebbadc2d74
commit
66b8a5e01d
@ -1,7 +1,6 @@
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/list.hpp"
|
||||
#include "irrt/math.hpp"
|
||||
#include "irrt/ndarray.hpp"
|
||||
#include "irrt/range.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
#include "irrt/string.hpp"
|
||||
@ -13,3 +12,4 @@
|
||||
#include "irrt/ndarray/reshape.hpp"
|
||||
#include "irrt/ndarray/broadcast.hpp"
|
||||
#include "irrt/ndarray/transpose.hpp"
|
||||
#include "irrt/ndarray/matmul.hpp"
|
@ -21,7 +21,5 @@ using uint64_t = unsigned _ExtInt(64);
|
||||
|
||||
#endif
|
||||
|
||||
// NDArray indices are always `uint32_t`.
|
||||
using NDIndexInt = uint32_t;
|
||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||
using SliceIndex = int32_t;
|
||||
|
@ -1,50 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
// TODO: To be deleted since NDArray with strides is done.
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
SizeT num_elems = 1;
|
||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||
SizeT val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
|
||||
SizeT stride = 1;
|
||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||
SizeT i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
uint64_t
|
||||
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
}
|
98
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
98
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/broadcast.hpp"
|
||||
#include "irrt/ndarray/iter.hpp"
|
||||
|
||||
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
||||
|
||||
namespace {
|
||||
namespace ndarray::matmul {
|
||||
|
||||
/**
|
||||
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||
*
|
||||
* Example:
|
||||
* Suppose `a_shape == [1, 97, 4, 2]`
|
||||
* and `b_shape == [99, 98, 1, 2, 5]`,
|
||||
*
|
||||
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
||||
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
||||
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
||||
* ^^^^^^^^^^ ^^^^
|
||||
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
||||
*
|
||||
* @param a_ndims Length of `a_shape`.
|
||||
* @param a_shape Shape of `a`.
|
||||
* @param b_ndims Length of `b_shape`.
|
||||
* @param b_shape Shape of `b`.
|
||||
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
||||
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void calculate_shapes(SizeT a_ndims,
|
||||
SizeT* a_shape,
|
||||
SizeT b_ndims,
|
||||
SizeT* b_shape,
|
||||
SizeT final_ndims,
|
||||
SizeT* new_a_shape,
|
||||
SizeT* new_b_shape,
|
||||
SizeT* dst_shape) {
|
||||
debug_assert(SizeT, a_ndims >= 2);
|
||||
debug_assert(SizeT, b_ndims >= 2);
|
||||
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
||||
|
||||
// Check that a and b are compatible for matmul
|
||||
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
|
||||
// This is a custom error message. Different from NumPy.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
||||
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
||||
}
|
||||
|
||||
const SizeT num_entries = 2;
|
||||
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
||||
{.ndims = b_ndims - 2, .shape = b_shape}};
|
||||
|
||||
// TODO: Optimize this
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
||||
|
||||
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
||||
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
||||
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||
}
|
||||
} // namespace ndarray::matmul
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::matmul;
|
||||
|
||||
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
|
||||
int32_t* a_shape,
|
||||
int32_t b_ndims,
|
||||
int32_t* b_shape,
|
||||
int32_t final_ndims,
|
||||
int32_t* new_a_shape,
|
||||
int32_t* new_b_shape,
|
||||
int32_t* dst_shape) {
|
||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
|
||||
int64_t* a_shape,
|
||||
int64_t b_ndims,
|
||||
int64_t* b_shape,
|
||||
int64_t final_ndims,
|
||||
int64_t* new_a_shape,
|
||||
int64_t* new_b_shape,
|
||||
int64_t* dst_shape) {
|
||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||
}
|
||||
}
|
@ -27,7 +27,7 @@ use super::{
|
||||
call_memcpy_generic,
|
||||
},
|
||||
macros::codegen_unreachable,
|
||||
need_sret, numpy,
|
||||
need_sret,
|
||||
stmt::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
gen_var,
|
||||
@ -1534,37 +1534,6 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val));
|
||||
let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val));
|
||||
|
||||
if op.base == Operator::MatMult {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
|
||||
let left = left.to_ndarray(generator, ctx);
|
||||
let right = right.to_ndarray(generator, ctx);
|
||||
|
||||
// MatMult is the only binop which is not an elementwise op
|
||||
let result = numpy::ndarray_matmul_2d(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
match op.variant {
|
||||
BinopVariant::Normal => None,
|
||||
BinopVariant::AugAssign => Some(left),
|
||||
},
|
||||
left,
|
||||
right,
|
||||
)?;
|
||||
|
||||
Ok(Some(result.as_base_value().into()))
|
||||
} else {
|
||||
// For other operations, they are all elementwise operations.
|
||||
|
||||
// There are only three cases:
|
||||
// - LHS is a scalar, RHS is an ndarray.
|
||||
// - LHS is an ndarray, RHS is a scalar.
|
||||
// - LHS is an ndarray, RHS is an ndarray.
|
||||
//
|
||||
// For all cases, the scalar operand is promoted to an ndarray,
|
||||
// the two are then broadcasted, and starmapped through.
|
||||
|
||||
let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1);
|
||||
let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2);
|
||||
|
||||
@ -1577,8 +1546,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let out = match op.variant {
|
||||
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype },
|
||||
BinopVariant::AugAssign => {
|
||||
// If this is an augmented assignment.
|
||||
// `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
|
||||
// Augmented assignment - `left` has to be an ndarray. If it were a scalar then NAC3
|
||||
// simply doesn't support it.
|
||||
if let ScalarOrNDArray::NDArray(out_ndarray) = left {
|
||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
|
||||
} else {
|
||||
@ -1587,6 +1556,24 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
};
|
||||
|
||||
if op.base == Operator::MatMult {
|
||||
let left = left.to_ndarray(generator, ctx);
|
||||
let right = right.to_ndarray(generator, ctx);
|
||||
let result = left
|
||||
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
|
||||
.split_unsized(generator, ctx);
|
||||
Ok(Some(result.to_basic_value_enum().into()))
|
||||
} else {
|
||||
// For other operations, they are all elementwise operations.
|
||||
|
||||
// There are only three cases:
|
||||
// - LHS is a scalar, RHS is an ndarray.
|
||||
// - LHS is an ndarray, RHS is a scalar.
|
||||
// - LHS is an ndarray, RHS is an ndarray.
|
||||
//
|
||||
// For all cases, the scalar operand is promoted to an ndarray,
|
||||
// the two are then broadcasted, and starmapped through.
|
||||
|
||||
let left = left.to_ndarray(generator, ctx);
|
||||
let right = right.to_ndarray(generator, ctx);
|
||||
|
||||
|
66
nac3core/src/codegen/irrt/ndarray/matmul.rs
Normal file
66
nac3core/src/codegen/irrt/ndarray/matmul.rs
Normal file
@ -0,0 +1,66 @@
|
||||
use inkwell::{types::BasicTypeEnum, values::IntValue};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function, irrt::get_usize_dependent_function_name,
|
||||
values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`.
|
||||
///
|
||||
/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of
|
||||
/// `a @ b`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
final_ndims: IntValue<'ctx>,
|
||||
new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
a_shape.size(ctx, generator).into(),
|
||||
a_shape.base_ptr(ctx, generator).into(),
|
||||
b_shape.size(ctx, generator).into(),
|
||||
b_shape.base_ptr(ctx, generator).into(),
|
||||
final_ndims.into(),
|
||||
new_a_shape.base_ptr(ctx, generator).into(),
|
||||
new_b_shape.base_ptr(ctx, generator).into(),
|
||||
dst_shape.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
@ -1,23 +1,9 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use super::get_usize_dependent_function_name;
|
||||
use crate::codegen::{
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
|
||||
TypedArrayLikeAdapter,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
pub use array::*;
|
||||
pub use basic::*;
|
||||
pub use broadcast::*;
|
||||
pub use indexing::*;
|
||||
pub use iter::*;
|
||||
pub use matmul::*;
|
||||
pub use reshape::*;
|
||||
pub use transpose::*;
|
||||
|
||||
@ -26,119 +12,6 @@ mod basic;
|
||||
mod broadcast;
|
||||
mod indexing;
|
||||
mod iter;
|
||||
mod matmul;
|
||||
mod reshape;
|
||||
mod transpose;
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a
|
||||
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.
|
||||
///
|
||||
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||
/// or [`None`] if starting from the first dimension and ending at the last dimension
|
||||
/// respectively.
|
||||
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dims: &Dims,
|
||||
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Dims: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize));
|
||||
assert!(end.is_none_or(|end| end.get_type() == llvm_usize));
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let ndarray_calc_size_fn_name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size");
|
||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
let ndarray_calc_size_fn =
|
||||
ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||
});
|
||||
|
||||
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_size_fn,
|
||||
&[
|
||||
dims.base_ptr(ctx, generator).into(),
|
||||
dims.size(ctx, generator).into(),
|
||||
begin.into(),
|
||||
end.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`]
|
||||
/// containing `i32` indices of the flattened index.
|
||||
///
|
||||
/// * `index` - The `llvm_usize` index to compute the multidimensional index for.
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(index.get_type(), llvm_usize);
|
||||
|
||||
let ndarray_calc_nd_indices_fn_name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices");
|
||||
let ndarray_calc_nd_indices_fn =
|
||||
ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.shape();
|
||||
|
||||
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_nd_indices_fn,
|
||||
&[
|
||||
index.into(),
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|
||||
|_, _, v| v.into_int_value(),
|
||||
|_, _, v| v.into(),
|
||||
)
|
||||
}
|
||||
|
@ -1,736 +1,23 @@
|
||||
use inkwell::{
|
||||
types::BasicType,
|
||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||
IntPredicate, OptimizationLevel,
|
||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||
IntPredicate,
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Operator, StrRef};
|
||||
use nac3parser::ast::StrRef;
|
||||
|
||||
use super::{
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt::{
|
||||
calculate_len_for_slice_range,
|
||||
ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size},
|
||||
},
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
macros::codegen_unreachable,
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
||||
values::{
|
||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||
UntypedArrayLikeMutator,
|
||||
},
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::ndarray::NDArrayType,
|
||||
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||
typecheck::{
|
||||
magic_methods::Binop,
|
||||
typedef::{FunSignature, Type},
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
/// Creates an `NDArray` instance from a dynamic shape.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The shape of the `NDArray`.
|
||||
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
||||
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
||||
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: &V,
|
||||
shape_len_fn: LenFn,
|
||||
shape_data_fn: DataFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
|
||||
DataFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&V,
|
||||
IntValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(shape_len, false),
|
||||
|generator, ctx, _, i| {
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
|
||||
let shape_dim_gez = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
shape_dim,
|
||||
shape_dim.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
shape_dim_gez,
|
||||
"0:ValueError",
|
||||
"negative dimensions not supported",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// TODO: Disallow shape > u32_MAX
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let num_dims = shape_len_fn(generator, ctx, shape)?;
|
||||
|
||||
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
|
||||
.construct_dyn_ndims(generator, ctx, num_dims, None);
|
||||
|
||||
// Copy the dimension sizes from shape to ndarray.dims
|
||||
let shape_len = shape_len_fn(generator, ctx, shape)?;
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(shape_len, false),
|
||||
|generator, ctx, _, i| {
|
||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||
|
||||
let ndarray_pdim =
|
||||
unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
|
||||
|
||||
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||
/// its input.
|
||||
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
IntValue<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let ndarray_num_elems = ndarray.size(generator, ctx);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(ndarray_num_elems, false),
|
||||
|generator, ctx, _, i| {
|
||||
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
|
||||
|
||||
let value = value_fn(generator, ctx, i)?;
|
||||
ctx.builder.build_store(elem, value).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
||||
/// as its input.
|
||||
fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| {
|
||||
let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray);
|
||||
|
||||
value_fn(generator, ctx, &indices)
|
||||
})
|
||||
}
|
||||
|
||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||
///
|
||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
||||
/// fields should be populated before calling this function.
|
||||
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the destination array.
|
||||
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
||||
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the source array.
|
||||
/// - `dim`: The index of the currently processing dimension.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
||||
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
|
||||
dim: u64,
|
||||
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||
) -> Result<(), String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
|
||||
|
||||
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
|
||||
|
||||
// If there are no (remaining) slice expressions, memcpy the entire dimension
|
||||
if slices.is_empty() {
|
||||
let stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&src_arr.shape(),
|
||||
(Some(llvm_usize.const_int(dim, false)), None),
|
||||
);
|
||||
let stride =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap();
|
||||
|
||||
let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();
|
||||
|
||||
call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero());
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// The stride of elements in this dimension, i.e. the number of elements between arr[i] and
|
||||
// arr[i + 1] in this dimension
|
||||
let src_stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&src_arr.shape(),
|
||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||
);
|
||||
let dst_stride = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&dst_arr.shape(),
|
||||
(Some(llvm_usize.const_int(dim + 1, false)), None),
|
||||
);
|
||||
|
||||
let (start, stop, step) = slices[0];
|
||||
let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
|
||||
let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
|
||||
let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();
|
||||
|
||||
let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();
|
||||
|
||||
gen_for_range_callback(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
false,
|
||||
|_, _| Ok(start),
|
||||
(|_, _| Ok(stop), true),
|
||||
|_, _| Ok(step),
|
||||
|generator, ctx, _, src_i| {
|
||||
// Calculate the offset of the active slice
|
||||
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
|
||||
let src_data_offset = ctx
|
||||
.builder
|
||||
.build_int_mul(
|
||||
src_data_offset,
|
||||
ctx.builder
|
||||
.build_int_z_extend_or_bit_cast(sizeof_elem, src_data_offset.get_type(), "")
|
||||
.unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let dst_i =
|
||||
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
|
||||
let dst_data_offset = ctx
|
||||
.builder
|
||||
.build_int_mul(
|
||||
dst_data_offset,
|
||||
ctx.builder
|
||||
.build_int_z_extend_or_bit_cast(sizeof_elem, dst_data_offset.get_type(), "")
|
||||
.unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (src_ptr, dst_ptr) = unsafe {
|
||||
(
|
||||
ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
|
||||
ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
|
||||
)
|
||||
};
|
||||
|
||||
ndarray_sliced_copyto_impl(
|
||||
generator,
|
||||
ctx,
|
||||
(dst_arr, dst_ptr),
|
||||
(src_arr, src_ptr),
|
||||
dim + 1,
|
||||
&slices[1..],
|
||||
)?;
|
||||
|
||||
let dst_i =
|
||||
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||
let dst_i_add1 =
|
||||
ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap();
|
||||
ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Copies a [`NDArrayValue`] using slices.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
||||
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
this: NDArrayValue<'ctx>,
|
||||
slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let ndarray =
|
||||
if slices.is_empty() {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&this,
|
||||
|_, ctx, shape| Ok(shape.load_ndims(ctx)),
|
||||
|generator, ctx, shape, idx| unsafe {
|
||||
Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||
},
|
||||
)?
|
||||
} else {
|
||||
let ndarray = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, None)
|
||||
.construct_dyn_ndims(generator, ctx, this.load_ndims(ctx), None);
|
||||
|
||||
// Populate the first slices.len() dimensions by computing the size of each dim slice
|
||||
for (i, (start, stop, step)) in slices.iter().enumerate() {
|
||||
// HACK: workaround calculate_len_for_slice_range requiring exclusive stop
|
||||
let stop = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
*step,
|
||||
llvm_i32.const_zero(),
|
||||
"is_neg",
|
||||
)
|
||||
.unwrap(),
|
||||
ctx.builder
|
||||
.build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
|
||||
.unwrap(),
|
||||
ctx.builder
|
||||
.build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
|
||||
.unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
|
||||
let slice_len =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
|
||||
|
||||
unsafe {
|
||||
ndarray.shape().set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
slice_len,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Populate the rest by directly copying the dim size from the source array
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_int(slices.len() as u64, false),
|
||||
(this.load_ndims(ctx), false),
|
||||
|generator, ctx, _, idx| {
|
||||
unsafe {
|
||||
let shape = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
|
||||
ndarray.shape().set_typed_unchecked(ctx, generator, &idx, shape);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
ndarray
|
||||
};
|
||||
|
||||
ndarray_sliced_copyto_impl(
|
||||
generator,
|
||||
ctx,
|
||||
(ndarray, ndarray.data().base_ptr(ctx, generator)),
|
||||
(this, this.data().base_ptr(ctx, generator)),
|
||||
0,
|
||||
slices,
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
this: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
|
||||
// lhs.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// rhs.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
if let Some(res) = res {
|
||||
let res_ndims = res.load_ndims(ctx);
|
||||
let res_dim0 = unsafe {
|
||||
res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
let res_dim1 = unsafe {
|
||||
res.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let lhs_dim0 = unsafe {
|
||||
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
let rhs_dim1 = unsafe {
|
||||
rhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
// res.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
res_ndims,
|
||||
llvm_usize.const_int(2, false),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// res.dims[0] == lhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// res.dims[1] == rhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let lhs_dim1 = unsafe {
|
||||
lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
};
|
||||
let rhs_dim0 = unsafe {
|
||||
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
// lhs.dims[1] == rhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
||||
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
||||
} else {
|
||||
lhs
|
||||
};
|
||||
|
||||
let ndarray = res.unwrap_or_else(|| {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&(lhs, rhs),
|
||||
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
||||
|generator, ctx, (lhs, rhs), idx| {
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
||||
.unwrap())
|
||||
},
|
||||
|generator, ctx| {
|
||||
Ok(Some(unsafe {
|
||||
lhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
}))
|
||||
},
|
||||
|generator, ctx| {
|
||||
Ok(Some(unsafe {
|
||||
rhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
}))
|
||||
},
|
||||
)
|
||||
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
||||
llvm_intrinsics::call_expect(
|
||||
ctx,
|
||||
idx.size(ctx, generator).get_type().const_int(2, false),
|
||||
idx.size(ctx, generator),
|
||||
None,
|
||||
);
|
||||
|
||||
let common_dim = {
|
||||
let lhs_idx1 = unsafe {
|
||||
lhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let rhs_idx0 = unsafe {
|
||||
rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
|
||||
};
|
||||
|
||||
let idx0 = unsafe {
|
||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
|
||||
};
|
||||
let idx1 = unsafe {
|
||||
let idx1 =
|
||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
||||
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
|
||||
};
|
||||
|
||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(common_dim, false),
|
||||
|generator, ctx, _, i| {
|
||||
let ab_idx = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_usize.into(),
|
||||
llvm_usize.const_int(2, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let a = unsafe {
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
||||
|
||||
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||
};
|
||||
let b = unsafe {
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
||||
ab_idx.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
idx1.into(),
|
||||
);
|
||||
|
||||
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||
};
|
||||
|
||||
let a_mul_b = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(elem_ty), a),
|
||||
Binop::normal(Operator::Mult),
|
||||
(&Some(elem_ty), b),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||
let result = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(elem_ty), result),
|
||||
Binop::normal(Operator::Add),
|
||||
(&Some(elem_ty), a_mul_b),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
ctx.builder.build_store(result_addr, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||
Ok(result)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -12,7 +12,7 @@ use crate::{
|
||||
};
|
||||
|
||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
|
334
nac3core/src/codegen/values/ndarray/matmul.rs
Normal file
334
nac3core/src/codegen/values/ndarray/matmul.rs
Normal file
@ -0,0 +1,334 @@
|
||||
use std::cmp::max;
|
||||
|
||||
use nac3parser::ast::Operator;
|
||||
|
||||
use super::{NDArrayOut, NDArrayValue, RustNDIndex};
|
||||
use crate::{
|
||||
codegen::{
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::ndarray::NDArrayType,
|
||||
values::{
|
||||
ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::helper::arraylike_flatten_element_type,
|
||||
typecheck::{magic_methods::Binop, typedef::Type},
|
||||
};
|
||||
|
||||
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
|
||||
///
|
||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dst_dtype: Type,
|
||||
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
|
||||
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
|
||||
) -> NDArrayValue<'ctx> {
|
||||
assert!(
|
||||
in_a.ndims.is_some_and(|ndims| ndims >= 2),
|
||||
"in_a (which is {:?}) must be compile-time known and >= 2",
|
||||
in_a.ndims
|
||||
);
|
||||
assert!(
|
||||
in_b.ndims.is_some_and(|ndims| ndims >= 2),
|
||||
"in_b (which is {:?}) must be compile-time known and >= 2",
|
||||
in_b.ndims
|
||||
);
|
||||
|
||||
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
|
||||
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
|
||||
|
||||
// Deduce ndims of the result of matmul.
|
||||
let ndims_int = max(in_a.ndims.unwrap(), in_b.ndims.unwrap());
|
||||
let ndims = llvm_usize.const_int(ndims_int, false);
|
||||
|
||||
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
||||
// destination ndarray to store the result of matmul.
|
||||
let (lhs, rhs, dst) = {
|
||||
let in_lhs_ndims = llvm_usize.const_int(in_a.ndims.unwrap(), false);
|
||||
let in_lhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
in_a.shape().base_ptr(ctx, generator),
|
||||
in_lhs_ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let in_rhs_ndims = llvm_usize.const_int(in_b.ndims.unwrap(), false);
|
||||
let in_rhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
in_b.shape().base_ptr(ctx, generator),
|
||||
in_rhs_ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let lhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let rhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let dst_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
// Matmul dimension compatibility is checked here.
|
||||
irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
&in_lhs_shape,
|
||||
&in_rhs_shape,
|
||||
ndims,
|
||||
&lhs_shape,
|
||||
&rhs_shape,
|
||||
&dst_shape,
|
||||
);
|
||||
|
||||
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
|
||||
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
|
||||
|
||||
let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, Some(ndims_int))
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
|
||||
unsafe {
|
||||
dst.create_data(generator, ctx);
|
||||
}
|
||||
|
||||
(lhs, rhs, dst)
|
||||
};
|
||||
|
||||
let len = unsafe {
|
||||
lhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(ndims_int - 1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let at_row = i64::try_from(ndims_int - 2).unwrap();
|
||||
let at_col = i64::try_from(ndims_int - 1).unwrap();
|
||||
|
||||
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
|
||||
let dst_zero = dst_dtype_llvm.const_zero();
|
||||
|
||||
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
||||
let pdst_ij = hdl.get_pointer(ctx);
|
||||
|
||||
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
|
||||
|
||||
let indices = hdl.get_indices::<G>();
|
||||
let i = unsafe {
|
||||
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None)
|
||||
};
|
||||
let j = unsafe {
|
||||
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None)
|
||||
};
|
||||
|
||||
let num_0 = llvm_usize.const_int(0, false);
|
||||
let num_1 = llvm_usize.const_int(1, false);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
num_0,
|
||||
(len, false),
|
||||
|generator, ctx, _, k| {
|
||||
// `indices` is modified to index into `a` and `b`, and restored.
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
i,
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
k.into(),
|
||||
);
|
||||
}
|
||||
let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
k.into(),
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
j,
|
||||
);
|
||||
}
|
||||
let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
// Restore `indices`.
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
i,
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
j,
|
||||
);
|
||||
}
|
||||
|
||||
// x = a_[...]ik * b_[...]kj
|
||||
let x = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(lhs_dtype), a_ik),
|
||||
Binop::normal(Operator::Mult),
|
||||
(&Some(rhs_dtype), b_kj),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
|
||||
// dst_[...]ij += x
|
||||
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||
let dst_ij = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(dst_dtype), dst_ij),
|
||||
Binop::normal(Operator::Add),
|
||||
(&Some(dst_dtype), x),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
num_1,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
dst
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
|
||||
///
|
||||
/// This function always return an [`NDArrayValue`]. You may want to use
|
||||
/// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar.
|
||||
///
|
||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||
#[must_use]
|
||||
pub fn matmul<G: CodeGenerator>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
self_ty: Type,
|
||||
(other_ty, other): (Type, Self),
|
||||
(out_dtype, out): (Type, NDArrayOut<'ctx>),
|
||||
) -> Self {
|
||||
// Sanity check, but type inference should prevent this.
|
||||
assert!(
|
||||
self.ndims.is_some_and(|ndims| ndims > 0) && other.ndims.is_some_and(|ndims| ndims > 0),
|
||||
"np.matmul disallows scalar input"
|
||||
);
|
||||
|
||||
// If both arguments are 2-D they are multiplied like conventional matrices.
|
||||
//
|
||||
// If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the
|
||||
// last two indices and broadcast accordingly.
|
||||
//
|
||||
// If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its
|
||||
// dimensions. After matrix multiplication the prepended 1 is removed.
|
||||
//
|
||||
// If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its
|
||||
// dimensions. After matrix multiplication the appended 1 is removed.
|
||||
|
||||
let new_a = if self.ndims.unwrap() == 1 {
|
||||
// Prepend 1 to its dimensions
|
||||
self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||
} else {
|
||||
*self
|
||||
};
|
||||
|
||||
let new_b = if other.ndims.unwrap() == 1 {
|
||||
// Append 1 to its dimensions
|
||||
other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||
} else {
|
||||
other
|
||||
};
|
||||
|
||||
// NOTE: `result` will always be a newly allocated ndarray.
|
||||
// Current implementation cannot do in-place matrix muliplication.
|
||||
let mut result =
|
||||
matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b));
|
||||
|
||||
// Postprocessing on the result to remove prepended/appended axes.
|
||||
let mut postindices = vec![];
|
||||
let zero = ctx.ctx.i32_type().const_zero();
|
||||
|
||||
if self.ndims.unwrap() == 1 {
|
||||
// Remove the prepended 1
|
||||
postindices.push(RustNDIndex::SingleElement(zero));
|
||||
}
|
||||
|
||||
if other.ndims.unwrap() == 1 {
|
||||
// Remove the appended 1
|
||||
postindices.push(RustNDIndex::Ellipsis);
|
||||
postindices.push(RustNDIndex::SingleElement(zero));
|
||||
}
|
||||
|
||||
if !postindices.is_empty() {
|
||||
result = result.index(generator, ctx, &postindices);
|
||||
}
|
||||
|
||||
match out {
|
||||
NDArrayOut::NewNDArray { .. } => result,
|
||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
||||
let result_shape = result.shape();
|
||||
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
|
||||
|
||||
out_ndarray.copy_data_from(generator, ctx, result);
|
||||
out_ndarray
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -32,6 +32,7 @@ mod broadcast;
|
||||
mod contiguous;
|
||||
mod indexing;
|
||||
mod map;
|
||||
mod matmul;
|
||||
mod nditer;
|
||||
pub mod shape;
|
||||
mod view;
|
||||
|
@ -8,5 +8,5 @@ expression: res_vec
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
||||
]
|
||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar238]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar238\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(251)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar237, typevar238]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar237\", \"typevar238\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(265)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n",
|
||||
]
|
||||
|
@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
|
||||
|
||||
use super::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
helper::{extract_ndims, PrimDef},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
},
|
||||
};
|
||||
@ -175,19 +175,8 @@ pub fn impl_binop(
|
||||
ops: &[Operator],
|
||||
) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||
(other_ty[0], None)
|
||||
} else {
|
||||
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
(tvar.ty, Some(tvar.id))
|
||||
};
|
||||
|
||||
let function_vars = if let Some(var_id) = other_var_id {
|
||||
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
|
||||
} else {
|
||||
VarMap::new()
|
||||
};
|
||||
|
||||
let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
let function_vars = into_var_map([other_tvar]);
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||
|
||||
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
|
||||
@ -198,7 +187,7 @@ pub fn impl_binop(
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
ty: other_tvar.ty,
|
||||
default_value: None,
|
||||
name: "other".into(),
|
||||
is_vararg: false,
|
||||
@ -541,36 +530,43 @@ pub fn typeof_binop(
|
||||
}
|
||||
}
|
||||
|
||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
||||
|
||||
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
||||
|
||||
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||
{
|
||||
return Err(format!(
|
||||
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
|
||||
unifier.stringify(lhs),
|
||||
unifier.stringify(rhs)
|
||||
));
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
|
||||
// Deduce the ndims of the resulting ndarray.
|
||||
// If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy.
|
||||
let result_ndims = match (lhs_ndims, rhs_ndims) {
|
||||
(0, _) | (_, 0) => {
|
||||
return Err(
|
||||
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
(1, 1) => 0,
|
||||
(1, _) => rhs_ndims - 1,
|
||||
(_, 1) => lhs_ndims - 1,
|
||||
(m, n) => max(m, n),
|
||||
};
|
||||
|
||||
match (lhs_ndims, rhs_ndims) {
|
||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
||||
return Err(format!(
|
||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
||||
u8::from(rhs == 0)
|
||||
))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
return Err(format!(
|
||||
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
|
||||
))
|
||||
}
|
||||
if result_ndims == 0 {
|
||||
// If the result is unsized, NumPy returns a scalar.
|
||||
primitives.float
|
||||
} else {
|
||||
let result_ndims_ty =
|
||||
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
|
||||
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
|
||||
}
|
||||
}
|
||||
|
||||
@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
|
Loading…
Reference in New Issue
Block a user