forked from M-Labs/nac3
[core] codegen: Reimplement ndarray cmpop
Based on 56cccce1
: core/ndstrides: implement cmpop
This commit is contained in:
parent
a2f1b25fd8
commit
ebbadc2d74
@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
|
|||||||
stride *= dims[i];
|
stride *= dims[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
|
||||||
SizeT lhs_ndims,
|
|
||||||
const SizeT* rhs_dims,
|
|
||||||
SizeT rhs_ndims,
|
|
||||||
SizeT* out_dims) {
|
|
||||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
|
||||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
|
||||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
|
||||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
|
||||||
|
|
||||||
if (lhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (rhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == 1) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (*rhs_dim_sz == 1) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
|
||||||
SizeT src_ndims,
|
|
||||||
const NDIndexInt* in_idx,
|
|
||||||
NDIndexInt* out_idx) {
|
|
||||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
|
||||||
SizeT src_i = src_ndims - i - 1;
|
|
||||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32
|
|||||||
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
|
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
|
||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
|
||||||
uint32_t lhs_ndims,
|
|
||||||
const uint32_t* rhs_dims,
|
|
||||||
uint32_t rhs_ndims,
|
|
||||||
uint32_t* out_dims) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
|
||||||
uint64_t lhs_ndims,
|
|
||||||
const uint64_t* rhs_dims,
|
|
||||||
uint64_t rhs_ndims,
|
|
||||||
uint64_t* out_dims) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
|
||||||
uint32_t src_ndims,
|
|
||||||
const NDIndexInt* in_idx,
|
|
||||||
NDIndexInt* out_idx) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
|
||||||
uint64_t src_ndims,
|
|
||||||
const NDIndexInt* in_idx,
|
|
||||||
NDIndexInt* out_idx) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -1852,37 +1852,39 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) };
|
||||||
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty);
|
||||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty);
|
||||||
let is_ndarray2 =
|
|
||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left))
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
.to_ndarray(generator, ctx);
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right))
|
||||||
|
.to_ndarray(generator, ctx);
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let result_ndarray = NDArrayType::new_broadcast(
|
||||||
|
generator,
|
||||||
let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty)
|
ctx.ctx,
|
||||||
.map_value(lhs.into_pointer_value(), None);
|
ctx.ctx.i8_type().into(),
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
&[left.get_type(), right.get_type()],
|
||||||
|
)
|
||||||
|
.broadcast_starmap(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
&[left, right],
|
||||||
None,
|
NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() },
|
||||||
(left_ty, left_val.as_base_value().into(), false),
|
|generator, ctx, scalars| {
|
||||||
(right_ty, rhs, false),
|
let left_scalar = scalars[0];
|
||||||
|generator, ctx, (lhs, rhs)| {
|
let right_scalar = scalars[1];
|
||||||
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(Some(ndarray_dtype1), lhs),
|
(Some(left_ty_dtype), left_scalar),
|
||||||
&[op],
|
&[op],
|
||||||
&[(Some(ndarray_dtype2), rhs)],
|
&[(Some(right_ty_dtype), right_scalar)],
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(
|
.to_basic_value_enum(
|
||||||
@ -1895,40 +1897,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
return Ok(Some(result_ndarray.as_base_value().into()));
|
||||||
} else {
|
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
|
||||||
&mut ctx.unifier,
|
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
|
||||||
);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
None,
|
|
||||||
(left_ty, lhs, !is_ndarray1),
|
|
||||||
(right_ty, rhs, !is_ndarray2),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(Some(ndarray_dtype), lhs),
|
|
||||||
&[op],
|
|
||||||
&[(Some(ndarray_dtype), rhs)],
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,18 +1,15 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicTypeEnum,
|
types::BasicTypeEnum,
|
||||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use super::get_usize_dependent_function_name;
|
use super::get_usize_dependent_function_name;
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
llvm_intrinsics,
|
|
||||||
macros::codegen_unreachable,
|
|
||||||
stmt::gen_for_callback_incrementing,
|
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
|
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
|
||||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
TypedArrayLikeAdapter,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|_, _, v| v.into(),
|
|_, _, v| v.into(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
|
||||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
|
||||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
lhs: NDArrayValue<'ctx>,
|
|
||||||
rhs: NDArrayValue<'ctx>,
|
|
||||||
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
let ndarray_calc_broadcast_fn_name =
|
|
||||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast");
|
|
||||||
let ndarray_calc_broadcast_fn =
|
|
||||||
ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_usize.fn_type(
|
|
||||||
&[
|
|
||||||
llvm_pusize.into(),
|
|
||||||
llvm_usize.into(),
|
|
||||||
llvm_pusize.into(),
|
|
||||||
llvm_usize.into(),
|
|
||||||
llvm_pusize.into(),
|
|
||||||
],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
|
||||||
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(min_ndims, false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
|
||||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
|
||||||
(
|
|
||||||
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
|
||||||
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
|
||||||
let lhs_eqz = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
|
||||||
.unwrap();
|
|
||||||
let rhs_eqz = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
|
||||||
.unwrap();
|
|
||||||
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
|
||||||
|
|
||||||
let lhs_eq_rhs = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
is_compatible,
|
|
||||||
"0:ValueError",
|
|
||||||
"operands could not be broadcast together",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
|
||||||
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
|
|
||||||
let lhs_ndims = lhs.load_ndims(ctx);
|
|
||||||
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
|
|
||||||
let rhs_ndims = rhs.load_ndims(ctx);
|
|
||||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
|
||||||
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(
|
|
||||||
ndarray_calc_broadcast_fn,
|
|
||||||
&[
|
|
||||||
lhs_dims.into(),
|
|
||||||
lhs_ndims.into(),
|
|
||||||
rhs_dims.into(),
|
|
||||||
rhs_ndims.into(),
|
|
||||||
out_dims.base_ptr(ctx, generator).into(),
|
|
||||||
],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
|
||||||
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
|
||||||
/// array `broadcast_idx`.
|
|
||||||
pub fn call_ndarray_calc_broadcast_index<
|
|
||||||
'ctx,
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
|
||||||
>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
array: NDArrayValue<'ctx>,
|
|
||||||
broadcast_idx: &BroadcastIdx,
|
|
||||||
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
|
||||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
|
||||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
|
||||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
|
||||||
};
|
|
||||||
let ndarray_calc_broadcast_fn =
|
|
||||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_usize.fn_type(
|
|
||||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
|
||||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
|
||||||
|
|
||||||
let array_dims = array.shape().base_ptr(ctx, generator);
|
|
||||||
let array_ndims = array.load_ndims(ctx);
|
|
||||||
let broadcast_idx_ptr = unsafe {
|
|
||||||
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(
|
|
||||||
ndarray_calc_broadcast_fn,
|
|
||||||
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
TypedArrayLikeAdapter::from(
|
|
||||||
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
|
||||||
|_, _, v| v.into_int_value(),
|
|
||||||
|_, _, v| v.into(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
@ -10,10 +10,7 @@ use super::{
|
|||||||
expr::gen_binop_expr_with_values,
|
expr::gen_binop_expr_with_values,
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range,
|
calculate_len_for_slice_range,
|
||||||
ndarray::{
|
ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size},
|
||||||
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
|
|
||||||
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
@ -21,7 +18,7 @@ use super::{
|
|||||||
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
||||||
values::{
|
values::{
|
||||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
|
||||||
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeMutator,
|
UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
@ -195,152 +192,6 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
|
||||||
/// the target `ndarray`.
|
|
||||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
target: NDArrayValue<'ctx>,
|
|
||||||
source: NDArrayValue<'ctx>,
|
|
||||||
) {
|
|
||||||
let array_ndims = source.load_ndims(ctx);
|
|
||||||
let broadcast_size = target.load_ndims(ctx);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"operands cannot be broadcast together",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
|
||||||
/// with broadcast-compatible shapes.
|
|
||||||
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
res: NDArrayValue<'ctx>,
|
|
||||||
(lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
|
||||||
(rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
|
||||||
value_fn: ValueFn,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
ValueFn: Fn(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
assert!(
|
|
||||||
!(lhs_scalar && rhs_scalar),
|
|
||||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
||||||
lhs_val.get_type(),
|
|
||||||
rhs_val.get_type()
|
|
||||||
);
|
|
||||||
|
|
||||||
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
|
|
||||||
// `indices` where necessary.
|
|
||||||
//
|
|
||||||
// Required for compatibility with `NDArrayType::get_unchecked`.
|
|
||||||
let get_data_by_indices_compat =
|
|
||||||
|generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
|
|
||||||
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
|
|
||||||
let indices = if llvm_usize == ctx.ctx.i32_type() {
|
|
||||||
indices
|
|
||||||
} else {
|
|
||||||
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
|
|
||||||
ArraySliceValue::from_ptr_val(
|
|
||||||
ctx.builder
|
|
||||||
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
|
|
||||||
.unwrap(),
|
|
||||||
indices.size(ctx, generator),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
|_, _, val| val.into_int_value(),
|
|
||||||
|_, _, val| val.into(),
|
|
||||||
);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(indices.size(ctx, generator), false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
|
|
||||||
let idx = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
|
|
||||||
.unwrap();
|
|
||||||
unsafe {
|
|
||||||
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
indices_usize
|
|
||||||
};
|
|
||||||
|
|
||||||
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
|
|
||||||
|
|
||||||
llvm_intrinsics::call_stackrestore(ctx, stackptr);
|
|
||||||
|
|
||||||
elem
|
|
||||||
};
|
|
||||||
|
|
||||||
// Assert that all ndarray operands are broadcastable to the target size
|
|
||||||
if !lhs_scalar {
|
|
||||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
||||||
.map_value(lhs_val.into_pointer_value(), None);
|
|
||||||
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !rhs_scalar {
|
|
||||||
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
||||||
.map_value(rhs_val.into_pointer_value(), None);
|
|
||||||
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
|
||||||
}
|
|
||||||
|
|
||||||
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
|
|
||||||
let lhs_elem = if lhs_scalar {
|
|
||||||
lhs_val
|
|
||||||
} else {
|
|
||||||
let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
||||||
.map_value(lhs_val.into_pointer_value(), None);
|
|
||||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
|
||||||
|
|
||||||
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
|
|
||||||
};
|
|
||||||
|
|
||||||
let rhs_elem = if rhs_scalar {
|
|
||||||
rhs_val
|
|
||||||
} else {
|
|
||||||
let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
||||||
.map_value(rhs_val.into_pointer_value(), None);
|
|
||||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
|
||||||
|
|
||||||
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
|
|
||||||
};
|
|
||||||
|
|
||||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||||
///
|
///
|
||||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
||||||
@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
|
||||||
///
|
|
||||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
|
||||||
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
|
|
||||||
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
|
|
||||||
/// `value_fn` arguments tuple for all output elements.
|
|
||||||
///
|
|
||||||
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
|
|
||||||
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
|
|
||||||
/// broadcast to all elements).
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
|
||||||
/// written to a new `ndarray`.
|
|
||||||
/// * `value_fn` - Function mapping the two input elements into the result.
|
|
||||||
///
|
|
||||||
/// # Panic
|
|
||||||
///
|
|
||||||
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
|
||||||
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
elem_ty: Type,
|
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
|
||||||
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
|
||||||
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
|
||||||
value_fn: ValueFn,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
ValueFn: Fn(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
|
||||||
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
|
||||||
|
|
||||||
assert!(
|
|
||||||
!(lhs_scalar && rhs_scalar),
|
|
||||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
|
||||||
lhs_val.get_type(),
|
|
||||||
rhs_val.get_type()
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray = res.unwrap_or_else(|| {
|
|
||||||
if lhs_scalar && rhs_scalar {
|
|
||||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
|
||||||
.map_value(lhs_val.into_pointer_value(), None);
|
|
||||||
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
|
||||||
.map_value(rhs_val.into_pointer_value(), None);
|
|
||||||
|
|
||||||
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
|
||||||
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&ndarray_dims,
|
|
||||||
|generator, ctx, v| Ok(v.size(ctx, generator)),
|
|
||||||
|generator, ctx, v, idx| unsafe {
|
|
||||||
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
} else {
|
|
||||||
let ndarray = NDArrayType::from_unifier_type(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
if lhs_scalar { rhs_ty } else { lhs_ty },
|
|
||||||
)
|
|
||||||
.map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None);
|
|
||||||
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&ndarray,
|
|
||||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
||||||
|generator, ctx, v, idx| unsafe {
|
|
||||||
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
|
|
||||||
value_fn(generator, ctx, elems)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
Loading…
Reference in New Issue
Block a user