ndstrides: [11] Implement general matmul & np_dot()
#521
@ -13,3 +13,4 @@
|
|||||||
#include "irrt/ndarray/reshape.hpp"
|
#include "irrt/ndarray/reshape.hpp"
|
||||||
#include "irrt/ndarray/broadcast.hpp"
|
#include "irrt/ndarray/broadcast.hpp"
|
||||||
#include "irrt/ndarray/transpose.hpp"
|
#include "irrt/ndarray/transpose.hpp"
|
||||||
|
#include "irrt/ndarray/matmul.hpp"
|
100
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
100
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/debug.hpp"
|
||||||
|
#include "irrt/exception.hpp"
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
#include "irrt/ndarray/basic.hpp"
|
||||||
|
#include "irrt/ndarray/broadcast.hpp"
|
||||||
|
#include "irrt/ndarray/iter.hpp"
|
||||||
|
|
||||||
|
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace matmul {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* Suppose `a_shape == [1, 97, 4, 2]`
|
||||||
|
* and `b_shape == [99, 98, 1, 2, 5]`,
|
||||||
|
*
|
||||||
|
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
||||||
|
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
||||||
|
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
||||||
|
* ^^^^^^^^^^ ^^^^
|
||||||
|
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
||||||
|
*
|
||||||
|
* @param a_ndims Length of `a_shape`.
|
||||||
|
* @param a_shape Shape of `a`.
|
||||||
|
* @param b_ndims Length of `b_shape`.
|
||||||
|
* @param b_shape Shape of `b`.
|
||||||
|
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
||||||
|
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void calculate_shapes(SizeT a_ndims,
|
||||||
|
SizeT* a_shape,
|
||||||
|
SizeT b_ndims,
|
||||||
|
SizeT* b_shape,
|
||||||
|
SizeT final_ndims,
|
||||||
|
SizeT* new_a_shape,
|
||||||
|
SizeT* new_b_shape,
|
||||||
|
SizeT* dst_shape) {
|
||||||
|
debug_assert(SizeT, a_ndims >= 2);
|
||||||
|
debug_assert(SizeT, b_ndims >= 2);
|
||||||
|
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
||||||
|
|
||||||
|
// Check that a and b are compatible for matmul
|
||||||
|
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
|
||||||
|
// This is a custom error message. Different from NumPy.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
||||||
|
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SizeT num_entries = 2;
|
||||||
|
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
||||||
|
{.ndims = b_ndims - 2, .shape = b_shape}};
|
||||||
|
|
||||||
|
// TODO: Optimize this
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
||||||
|
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
||||||
|
|
||||||
|
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
||||||
|
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
||||||
|
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||||
|
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||||
|
}
|
||||||
|
} // namespace matmul
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::matmul;
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
|
||||||
|
int32_t* a_shape,
|
||||||
|
int32_t b_ndims,
|
||||||
|
int32_t* b_shape,
|
||||||
|
int32_t final_ndims,
|
||||||
|
int32_t* new_a_shape,
|
||||||
|
int32_t* new_b_shape,
|
||||||
|
int32_t* dst_shape) {
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
|
||||||
|
int64_t* a_shape,
|
||||||
|
int64_t b_ndims,
|
||||||
|
int64_t* b_shape,
|
||||||
|
int64_t final_ndims,
|
||||||
|
int64_t* new_a_shape,
|
||||||
|
int64_t* new_b_shape,
|
||||||
|
int64_t* dst_shape) {
|
||||||
|
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||||
|
}
|
||||||
|
}
|
@ -1580,7 +1580,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
if op.base == Operator::MatMult {
|
if op.base == Operator::MatMult {
|
||||||
// Handle matrix multiplication.
|
// Handle matrix multiplication.
|
||||||
todo!()
|
let left = left.to_ndarray(generator, ctx);
|
||||||
|
let right = right.to_ndarray(generator, ctx);
|
||||||
|
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
||||||
|
.split_unsized(generator, ctx);
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
|
||||||
} else {
|
} else {
|
||||||
// For other operations, they are all elementwise operations.
|
// For other operations, they are all elementwise operations.
|
||||||
|
|
||||||
|
@ -1220,3 +1220,30 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
.arg(axes)
|
.arg(axes)
|
||||||
.returning_void();
|
.returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
b_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
final_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
new_a_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
new_b_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
dst_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name =
|
||||||
|
get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||||
|
FnCall::builder(generator, ctx, &name)
|
||||||
|
.arg(a_ndims)
|
||||||
|
.arg(a_shape)
|
||||||
|
.arg(b_ndims)
|
||||||
|
.arg(b_shape)
|
||||||
|
.arg(final_ndims)
|
||||||
|
.arg(new_a_shape)
|
||||||
|
.arg(new_b_shape)
|
||||||
|
.arg(dst_shape)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
@ -12,7 +12,6 @@ use super::{
|
|||||||
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
expr::gen_binop_expr_with_values,
|
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
||||||
@ -22,9 +21,12 @@ use super::{
|
|||||||
model::*,
|
model::*,
|
||||||
object::{
|
object::{
|
||||||
any::AnyObject,
|
any::AnyObject,
|
||||||
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||||
|
},
|
||||||
|
stmt::{
|
||||||
|
gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback,
|
||||||
|
gen_if_else_expr_callback,
|
||||||
},
|
},
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -34,10 +36,7 @@ use crate::{
|
|||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::typedef::{FunSignature, Type},
|
||||||
magic_methods::Binop,
|
|
||||||
typedef::{FunSignature, Type},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
@ -1437,302 +1436,6 @@ where
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
|
||||||
/// written to a new `ndarray`.
|
|
||||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
|
||||||
lhs: NDArrayValue<'ctx>,
|
|
||||||
rhs: NDArrayValue<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
|
||||||
|
|
||||||
// lhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// rhs.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(res) = res {
|
|
||||||
let res_ndims = res.load_ndims(ctx);
|
|
||||||
let res_dim0 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let res_dim1 = unsafe {
|
|
||||||
res.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let lhs_dim0 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
let rhs_dim1 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// res.ndims == 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
res_ndims,
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[0] == lhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// res.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
|
||||||
let lhs_dim1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_dim0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
// lhs.dims[1] == rhs.dims[0]
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
|
|
||||||
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
|
||||||
} else {
|
|
||||||
lhs
|
|
||||||
};
|
|
||||||
|
|
||||||
let ndarray = res.unwrap_or_else(|| {
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&(lhs, rhs),
|
|
||||||
|_, _, _| Ok(llvm_usize.const_int(2, false)),
|
|
||||||
|generator, ctx, (lhs, rhs), idx| {
|
|
||||||
gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "")
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
|generator, ctx| {
|
|
||||||
Ok(Some(unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| {
|
|
||||||
llvm_intrinsics::call_expect(
|
|
||||||
ctx,
|
|
||||||
idx.size(ctx, generator).get_type().const_int(2, false),
|
|
||||||
idx.size(ctx, generator),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let common_dim = {
|
|
||||||
let lhs_idx1 = unsafe {
|
|
||||||
lhs.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let rhs_idx0 = unsafe {
|
|
||||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let idx0 = unsafe {
|
|
||||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
let idx1 = unsafe {
|
|
||||||
let idx1 =
|
|
||||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
|
||||||
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
|
||||||
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_i32.const_zero(),
|
|
||||||
(common_dim, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
|
||||||
|
|
||||||
let ab_idx = generator.gen_array_var_alloc(
|
|
||||||
ctx,
|
|
||||||
llvm_i32.into(),
|
|
||||||
llvm_usize.const_int(2, false),
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let a = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
|
||||||
|
|
||||||
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
let b = unsafe {
|
|
||||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
|
||||||
ab_idx.set_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_int(1, false),
|
|
||||||
idx1.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let a_mul_b = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), a),
|
|
||||||
Binop::normal(Operator::Mult),
|
|
||||||
(&Some(elem_ty), b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
let result = gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(elem_ty), result),
|
|
||||||
Binop::normal(Operator::Add),
|
|
||||||
(&Some(elem_ty), a_mul_b),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, elem_ty)?;
|
|
||||||
ctx.builder.build_store(result_addr, result).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
|
||||||
Ok(result)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
@ -2004,77 +1707,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "ndarray_dot";
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
let (x1_ty, x1) = x1;
|
let (x1_ty, x1) = x1;
|
||||||
let (_, x2) = x2;
|
let (x2_ty, x2) = x2;
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
match (x1, x2) {
|
match (x1, x2) {
|
||||||
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
(BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => {
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
let a = AnyObject { ty: x1_ty, value: x1 };
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
let b = AnyObject { ty: x2_ty, value: x2 };
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let a = NDArrayObject::from_object(generator, ctx, a);
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let b = NDArrayObject::from_object(generator, ctx, b);
|
||||||
|
|
||||||
|
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
|
||||||
|
assert_eq!(a.ndims, 1);
|
||||||
|
assert_eq!(b.ndims, 1);
|
||||||
|
let common_dtype = a.dtype;
|
||||||
|
|
||||||
|
// Check shapes.
|
||||||
|
let a_size = a.size(generator, ctx);
|
||||||
|
let b_size = b.size(generator, ctx);
|
||||||
|
let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size);
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
same_shape.value,
|
||||||
"0:ValueError",
|
"0:ValueError",
|
||||||
"shapes ({0}), ({1}) not aligned",
|
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
|
||||||
[Some(n1_sz), Some(n2_sz), None],
|
[Some(a_size.value), Some(b_size.value), None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
let identity =
|
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
|
||||||
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
||||||
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
|
||||||
|
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
|
||||||
|
|
||||||
|
// Do dot product.
|
||||||
|
gen_for_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
None,
|
Some("np_dot"),
|
||||||
llvm_usize.const_zero(),
|
|generator, ctx| {
|
||||||
(n1_sz, false),
|
let a_iter = NDIterHandle::new(generator, ctx, a);
|
||||||
|generator, ctx, _, idx| {
|
let b_iter = NDIterHandle::new(generator, ctx, b);
|
||||||
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
Ok((a_iter, b_iter))
|
||||||
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
},
|
||||||
|
|generator, ctx, (a_iter, _b_iter)| {
|
||||||
|
// Only a_iter drives the condition, b_iter should have the same status.
|
||||||
|
Ok(a_iter.has_element(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (a_iter, b_iter)| {
|
||||||
|
let a_scalar = a_iter.get_scalar(generator, ctx).value;
|
||||||
|
let b_scalar = b_iter.get_scalar(generator, ctx).value;
|
||||||
|
|
||||||
let product = match elem1 {
|
let old_result = ctx.builder.build_load(result, "").unwrap();
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
let new_result: BasicValueEnum<'ctx> = match old_result {
|
||||||
.builder
|
BasicValueEnum::IntValue(old_result) => {
|
||||||
.build_int_mul(e1, elem2.into_int_value(), "")
|
let a_scalar = a_scalar.into_int_value();
|
||||||
.unwrap()
|
let b_scalar = b_scalar.into_int_value();
|
||||||
.as_basic_value_enum(),
|
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
|
||||||
.builder
|
}
|
||||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
BasicValueEnum::FloatValue(old_result) => {
|
||||||
.unwrap()
|
let a_scalar = a_scalar.into_float_value();
|
||||||
.as_basic_value_enum(),
|
let b_scalar = b_scalar.into_float_value();
|
||||||
_ => codegen_unreachable!(ctx),
|
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
|
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
||||||
let acc_val = match acc_val {
|
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(e1, product.into_int_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_float_add(e1, product.into_float_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
||||||
|
|
||||||
|
ctx.builder.build_store(result, new_result).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
llvm_usize.const_int(1, false),
|
|generator, ctx, (a_iter, b_iter)| {
|
||||||
)?;
|
a_iter.next(generator, ctx);
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
b_iter.next(generator, ctx);
|
||||||
Ok(acc_val)
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_load(result, "").unwrap())
|
||||||
}
|
}
|
||||||
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||||
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
216
nac3core/src/codegen/object/ndarray/matmul.rs
Normal file
216
nac3core/src/codegen/object/ndarray/matmul.rs
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
use std::cmp::max;
|
||||||
|
|
||||||
|
use nac3parser::ast::Operator;
|
||||||
|
|
||||||
|
use super::{util::gen_for_model, NDArrayObject, NDArrayOut};
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes,
|
||||||
|
model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::{magic_methods::Binop, typedef::Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
|
||||||
|
///
|
||||||
|
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||||
|
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dst_dtype: Type,
|
||||||
|
in_a: NDArrayObject<'ctx>,
|
||||||
|
in_b: NDArrayObject<'ctx>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
assert!(in_a.ndims >= 2);
|
||||||
|
assert!(in_b.ndims >= 2);
|
||||||
|
|
||||||
|
// Deduce ndims of the result of matmul.
|
||||||
|
let ndims_int = max(in_a.ndims, in_b.ndims);
|
||||||
|
let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false);
|
||||||
|
|
||||||
|
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
||||||
|
// destination ndarray to store the result of matmul.
|
||||||
|
let (lhs, rhs, dst) = {
|
||||||
|
let in_lhs_ndims = in_a.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let in_lhs_shape = in_a.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let in_rhs_ndims = in_b.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let in_rhs_shape = in_b.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
let lhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
let rhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value);
|
||||||
|
|
||||||
|
// Matmul dimension compatibility is checked here.
|
||||||
|
call_nac3_ndarray_matmul_calculate_shapes(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
in_lhs_ndims,
|
||||||
|
in_lhs_shape,
|
||||||
|
in_rhs_ndims,
|
||||||
|
in_rhs_shape,
|
||||||
|
ndims,
|
||||||
|
lhs_shape,
|
||||||
|
rhs_shape,
|
||||||
|
dst_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, lhs_shape);
|
||||||
|
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, rhs_shape);
|
||||||
|
|
||||||
|
let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int);
|
||||||
|
dst.copy_shape_from_array(generator, ctx, dst_shape);
|
||||||
|
dst.create_data(generator, ctx);
|
||||||
|
|
||||||
|
(lhs, rhs, dst)
|
||||||
|
};
|
||||||
|
|
||||||
|
let len = lhs.instance.get(generator, ctx, |f| f.shape).get_index_const(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
i64::try_from(ndims_int - 1).unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let at_row = i64::try_from(ndims_int - 2).unwrap();
|
||||||
|
let at_col = i64::try_from(ndims_int - 1).unwrap();
|
||||||
|
|
||||||
|
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
|
||||||
|
let dst_zero = dst_dtype_llvm.const_zero();
|
||||||
|
|
||||||
|
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
||||||
|
let pdst_ij = hdl.get_pointer(generator, ctx);
|
||||||
|
|
||||||
|
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
|
||||||
|
|
||||||
|
let indices = hdl.get_indices();
|
||||||
|
let i = indices.get_index_const(generator, ctx, at_row);
|
||||||
|
let j = indices.get_index_const(generator, ctx, at_col);
|
||||||
|
|
||||||
|
let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0, false);
|
||||||
|
let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1, false);
|
||||||
|
|
||||||
|
gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| {
|
||||||
|
// `indices` is modified to index into `a` and `b`, and restored.
|
||||||
|
indices.set_index_const(ctx, at_row, i);
|
||||||
|
indices.set_index_const(ctx, at_col, k);
|
||||||
|
let a_ik = lhs.get_scalar_by_indices(generator, ctx, indices);
|
||||||
|
|
||||||
|
indices.set_index_const(ctx, at_row, k);
|
||||||
|
indices.set_index_const(ctx, at_col, j);
|
||||||
|
let b_kj = rhs.get_scalar_by_indices(generator, ctx, indices);
|
||||||
|
|
||||||
|
// Restore `indices`.
|
||||||
|
indices.set_index_const(ctx, at_row, i);
|
||||||
|
indices.set_index_const(ctx, at_col, j);
|
||||||
|
|
||||||
|
// x = a_[...]ik * b_[...]kj
|
||||||
|
let x = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(lhs.dtype), a_ik.value),
|
||||||
|
Binop::normal(Operator::Mult),
|
||||||
|
(&Some(rhs.dtype), b_kj.value),
|
||||||
|
ctx.current_loc,
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||||
|
|
||||||
|
// dst_[...]ij += x
|
||||||
|
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||||
|
let dst_ij = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(dst_dtype), dst_ij),
|
||||||
|
Binop::normal(Operator::Add),
|
||||||
|
(&Some(dst_dtype), x),
|
||||||
|
ctx.current_loc,
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||||
|
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
dst
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Perform `np.matmul` according to the rules in
|
||||||
|
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
|
||||||
|
///
|
||||||
|
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`]
|
||||||
|
/// to handle when the output could be a scalar.
|
||||||
|
///
|
||||||
|
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||||
|
pub fn matmul<G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: Self,
|
||||||
|
b: Self,
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
// Sanity check, but type inference should prevent this.
|
||||||
|
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
|
||||||
|
|
||||||
|
/*
|
||||||
|
If both arguments are 2-D they are multiplied like conventional matrices.
|
||||||
|
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly.
|
||||||
|
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
|
||||||
|
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
|
||||||
|
*/
|
||||||
|
|
||||||
|
let new_a = if a.ndims == 1 {
|
||||||
|
// Prepend 1 to its dimensions
|
||||||
|
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||||
|
} else {
|
||||||
|
a
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_b = if b.ndims == 1 {
|
||||||
|
// Append 1 to its dimensions
|
||||||
|
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||||
|
} else {
|
||||||
|
b
|
||||||
|
};
|
||||||
|
|
||||||
|
// NOTE: `result` will always be a newly allocated ndarray.
|
||||||
|
// Current implementation cannot do in-place matrix muliplication.
|
||||||
|
let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b);
|
||||||
|
|
||||||
|
// Postprocessing on the result to remove prepended/appended axes.
|
||||||
|
let mut postindices = vec![];
|
||||||
|
let zero = Int(Int32).const_0(generator, ctx.ctx);
|
||||||
|
|
||||||
|
if a.ndims == 1 {
|
||||||
|
// Remove the prepended 1
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.ndims == 1 {
|
||||||
|
// Remove the appended 1
|
||||||
|
postindices.push(RustNDIndex::Ellipsis);
|
||||||
|
postindices.push(RustNDIndex::SingleElement(zero));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !postindices.is_empty() {
|
||||||
|
result = result.index(generator, ctx, &postindices);
|
||||||
|
}
|
||||||
|
|
||||||
|
match out {
|
||||||
|
NDArrayOut::NewNDArray { .. } => result,
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
||||||
|
let result_shape = result.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
out_ndarray.assert_can_be_written_by_out(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
result.ndims,
|
||||||
|
result_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
out_ndarray.copy_data_from(generator, ctx, result);
|
||||||
|
out_ndarray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -27,6 +27,7 @@ pub mod broadcast;
|
|||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod map;
|
pub mod map;
|
||||||
|
pub mod matmul;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
@ -2080,10 +2080,12 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?;
|
||||||
|
Ok(Some(result))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar239]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar239\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar238, typevar239]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar238\", \"typevar239\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(266)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
|
|||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::{extract_ndims, PrimDef},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -175,19 +175,8 @@ pub fn impl_binop(
|
|||||||
ops: &[Operator],
|
ops: &[Operator],
|
||||||
) {
|
) {
|
||||||
with_fields(unifier, ty, |unifier, fields| {
|
with_fields(unifier, ty, |unifier, fields| {
|
||||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||||
(other_ty[0], None)
|
let function_vars = into_var_map([other_tvar]);
|
||||||
} else {
|
|
||||||
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
|
||||||
(tvar.ty, Some(tvar.id))
|
|
||||||
};
|
|
||||||
|
|
||||||
let function_vars = if let Some(var_id) = other_var_id {
|
|
||||||
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
|
|
||||||
} else {
|
|
||||||
VarMap::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
|
||||||
|
|
||||||
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
|
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
|
||||||
@ -198,7 +187,7 @@ pub fn impl_binop(
|
|||||||
ret: ret_ty,
|
ret: ret_ty,
|
||||||
vars: function_vars.clone(),
|
vars: function_vars.clone(),
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
ty: other_ty,
|
ty: other_tvar.ty,
|
||||||
default_value: None,
|
default_value: None,
|
||||||
name: "other".into(),
|
name: "other".into(),
|
||||||
is_vararg: false,
|
is_vararg: false,
|
||||||
@ -541,36 +530,43 @@ pub fn typeof_binop(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = extract_ndims(unifier, lhs_ndims);
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
|
||||||
assert_eq!(values.len(), 1);
|
let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
let rhs_ndims = extract_ndims(unifier, rhs_ndims);
|
||||||
|
|
||||||
|
if !(unifier.unioned(lhs_dtype, primitives.float)
|
||||||
|
&& unifier.unioned(rhs_dtype, primitives.float))
|
||||||
|
{
|
||||||
|
return Err(format!(
|
||||||
|
"ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}",
|
||||||
|
unifier.stringify(lhs),
|
||||||
|
unifier.stringify(rhs)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
// Deduce the ndims of the resulting ndarray.
|
||||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
// If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy.
|
||||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
let result_ndims = match (lhs_ndims, rhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
(0, _) | (_, 0) => {
|
||||||
assert_eq!(values.len(), 1);
|
return Err(
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
"ndarray.__matmul__ does not allow unsized ndarray input".to_string()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
(1, 1) => 0,
|
||||||
|
(1, _) => rhs_ndims - 1,
|
||||||
|
(_, 1) => lhs_ndims - 1,
|
||||||
|
(m, n) => max(m, n),
|
||||||
};
|
};
|
||||||
|
|
||||||
match (lhs_ndims, rhs_ndims) {
|
if result_ndims == 0 {
|
||||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
// If the result is unsized, NumPy returns a scalar.
|
||||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
primitives.float
|
||||||
return Err(format!(
|
} else {
|
||||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
let result_ndims_ty =
|
||||||
u8::from(rhs == 0)
|
unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None);
|
||||||
))
|
make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty))
|
||||||
}
|
|
||||||
(lhs, rhs) => {
|
|
||||||
return Err(format!(
|
|
||||||
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None);
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
Loading…
Reference in New Issue
Block a user