[core] codegen: Reimplement ndarray cmpop
Based on 56cccce1
: core/ndstrides: implement cmpop
This commit is contained in:
parent
f43523ec72
commit
7da45a4fbd
@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||
SizeT lhs_ndims,
|
||||
const SizeT* rhs_dims,
|
||||
SizeT rhs_ndims,
|
||||
SizeT* out_dims) {
|
||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == nullptr) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == nullptr) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
||||
SizeT src_ndims,
|
||||
const NDIndexInt* in_idx,
|
||||
NDIndexInt* out_idx) {
|
||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||
SizeT src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32
|
||||
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t* rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t* rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
||||
uint32_t src_ndims,
|
||||
const NDIndexInt* in_idx,
|
||||
NDIndexInt* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
||||
uint64_t src_ndims,
|
||||
const NDIndexInt* in_idx,
|
||||
NDIndexInt* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
}
|
@ -1875,37 +1875,39 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
||||
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||
let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) };
|
||||
let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||
let op = ops[0];
|
||||
|
||||
let is_ndarray1 =
|
||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_ndarray2 =
|
||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty);
|
||||
let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty);
|
||||
|
||||
return if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||
let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left))
|
||||
.to_ndarray(generator, ctx);
|
||||
let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right))
|
||||
.to_ndarray(generator, ctx);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty)
|
||||
.map_value(lhs.into_pointer_value(), None);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
let result_ndarray = NDArrayType::new_broadcast(
|
||||
generator,
|
||||
ctx.ctx,
|
||||
ctx.ctx.i8_type().into(),
|
||||
&[left.get_type(), right.get_type()],
|
||||
)
|
||||
.broadcast_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(left_ty, left_val.as_base_value().into(), false),
|
||||
(right_ty, rhs, false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
&[left, right],
|
||||
NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() },
|
||||
|generator, ctx, scalars| {
|
||||
let left_scalar = scalars[0];
|
||||
let right_scalar = scalars[1];
|
||||
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype1), lhs),
|
||||
(Some(left_ty_dtype), left_scalar),
|
||||
&[op],
|
||||
&[(Some(ndarray_dtype2), rhs)],
|
||||
&[(Some(right_ty_dtype), right_scalar)],
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(
|
||||
@ -1918,40 +1920,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||
&mut ctx.unifier,
|
||||
if is_ndarray1 { left_ty } else { right_ty },
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(left_ty, lhs, !is_ndarray1),
|
||||
(right_ty, rhs, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype), lhs),
|
||||
&[op],
|
||||
&[(Some(ndarray_dtype), rhs)],
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.primitives.bool,
|
||||
)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
};
|
||||
return Ok(Some(result_ndarray.as_base_value().into()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,18 +1,15 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use super::get_usize_dependent_function_name;
|
||||
use crate::codegen::{
|
||||
llvm_intrinsics,
|
||||
macros::codegen_unreachable,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|_, _, v| v.into(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast");
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(min_ndims, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||
(
|
||||
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
)
|
||||
};
|
||||
|
||||
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
||||
let lhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let rhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
||||
|
||||
let lhs_eq_rhs = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
||||
.unwrap();
|
||||
|
||||
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
is_compatible,
|
||||
"0:ValueError",
|
||||
"operands could not be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[
|
||||
lhs_dims.into(),
|
||||
lhs_ndims.into(),
|
||||
rhs_dims.into(),
|
||||
rhs_ndims.into(),
|
||||
out_dims.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into())
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
||||
/// array `broadcast_idx`.
|
||||
pub fn call_ndarray_calc_broadcast_index<
|
||||
'ctx,
|
||||
G: CodeGenerator + ?Sized,
|
||||
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
||||
>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
array: NDArrayValue<'ctx>,
|
||||
broadcast_idx: &BroadcastIdx,
|
||||
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||
|
||||
let array_dims = array.shape().base_ptr(ctx, generator);
|
||||
let array_ndims = array.load_ndims(ctx);
|
||||
let broadcast_idx_ptr = unsafe {
|
||||
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
||||
|_, _, v| v.into_int_value(),
|
||||
|_, _, v| v.into(),
|
||||
)
|
||||
}
|
||||
|
@ -10,10 +10,7 @@ use super::{
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt::{
|
||||
calculate_len_for_slice_range,
|
||||
ndarray::{
|
||||
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
|
||||
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
|
||||
},
|
||||
ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size},
|
||||
},
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
macros::codegen_unreachable,
|
||||
@ -21,7 +18,7 @@ use super::{
|
||||
types::ndarray::{factory::ndarray_zero_value, NDArrayType},
|
||||
values::{
|
||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||
UntypedArrayLikeMutator,
|
||||
},
|
||||
@ -195,152 +192,6 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||
/// the target `ndarray`.
|
||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target: NDArrayValue<'ctx>,
|
||||
source: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let array_ndims = source.load_ndims(ctx);
|
||||
let broadcast_size = target.load_ndims(ctx);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"operands cannot be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
||||
/// with broadcast-compatible shapes.
|
||||
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
res: NDArrayValue<'ctx>,
|
||||
(lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
||||
(rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
assert!(
|
||||
!(lhs_scalar && rhs_scalar),
|
||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
||||
lhs_val.get_type(),
|
||||
rhs_val.get_type()
|
||||
);
|
||||
|
||||
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
|
||||
// `indices` where necessary.
|
||||
//
|
||||
// Required for compatibility with `NDArrayType::get_unchecked`.
|
||||
let get_data_by_indices_compat =
|
||||
|generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
|
||||
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
|
||||
let indices = if llvm_usize == ctx.ctx.i32_type() {
|
||||
indices
|
||||
} else {
|
||||
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder
|
||||
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
|
||||
.unwrap(),
|
||||
indices.size(ctx, generator),
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(indices.size(ctx, generator), false),
|
||||
|generator, ctx, _, i| {
|
||||
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
|
||||
let idx = ctx
|
||||
.builder
|
||||
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
|
||||
.unwrap();
|
||||
unsafe {
|
||||
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
indices_usize
|
||||
};
|
||||
|
||||
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
llvm_intrinsics::call_stackrestore(ctx, stackptr);
|
||||
|
||||
elem
|
||||
};
|
||||
|
||||
// Assert that all ndarray operands are broadcastable to the target size
|
||||
if !lhs_scalar {
|
||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
||||
.map_value(lhs_val.into_pointer_value(), None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
||||
}
|
||||
|
||||
if !rhs_scalar {
|
||||
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
||||
.map_value(rhs_val.into_pointer_value(), None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
||||
}
|
||||
|
||||
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
|
||||
let lhs_elem = if lhs_scalar {
|
||||
lhs_val
|
||||
} else {
|
||||
let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
||||
.map_value(lhs_val.into_pointer_value(), None);
|
||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
||||
|
||||
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
|
||||
};
|
||||
|
||||
let rhs_elem = if rhs_scalar {
|
||||
rhs_val
|
||||
} else {
|
||||
let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
||||
.map_value(rhs_val.into_pointer_value(), None);
|
||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
||||
|
||||
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
|
||||
};
|
||||
|
||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
||||
})?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||
///
|
||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
|
||||
@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
||||
///
|
||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
||||
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
|
||||
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
|
||||
/// `value_fn` arguments tuple for all output elements.
|
||||
///
|
||||
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
|
||||
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
|
||||
/// broadcast to all elements).
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
/// * `value_fn` - Function mapping the two input elements into the result.
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
||||
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
lhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||
rhs: (Type, BasicValueEnum<'ctx>, bool),
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
|
||||
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
|
||||
|
||||
assert!(
|
||||
!(lhs_scalar && rhs_scalar),
|
||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
||||
lhs_val.get_type(),
|
||||
rhs_val.get_type()
|
||||
);
|
||||
|
||||
let ndarray = res.unwrap_or_else(|| {
|
||||
if lhs_scalar && rhs_scalar {
|
||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
||||
.map_value(lhs_val.into_pointer_value(), None);
|
||||
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
|
||||
.map_value(rhs_val.into_pointer_value(), None);
|
||||
|
||||
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
||||
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&ndarray_dims,
|
||||
|generator, ctx, v| Ok(v.size(ctx, generator)),
|
||||
|generator, ctx, v, idx| unsafe {
|
||||
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
} else {
|
||||
let ndarray = NDArrayType::from_unifier_type(
|
||||
generator,
|
||||
ctx,
|
||||
if lhs_scalar { rhs_ty } else { lhs_ty },
|
||||
)
|
||||
.map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None);
|
||||
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&ndarray,
|
||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
||||
|generator, ctx, v, idx| unsafe {
|
||||
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
});
|
||||
|
||||
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
|
||||
value_fn(generator, ctx, elems)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
|
Loading…
Reference in New Issue
Block a user