[core] codegen: Reimplement ndarray cmpop

Based on 56cccce1: core/ndstrides: implement cmpop
This commit is contained in:
David Mak 2024-12-19 10:46:24 +08:00
parent a2f1b25fd8
commit ebbadc2d74
4 changed files with 43 additions and 554 deletions

View File

@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
stride *= dims[i]; stride *= dims[i];
} }
} }
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace } // namespace
extern "C" { extern "C" {
@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} }

View File

@ -1852,83 +1852,52 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) };
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) };
let op = ops[0]; let op = ops[0];
let is_ndarray1 = let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty);
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty);
let is_ndarray2 =
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left))
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); .to_ndarray(generator, ctx);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right))
.to_ndarray(generator, ctx);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let result_ndarray = NDArrayType::new_broadcast(
generator,
ctx.ctx,
ctx.ctx.i8_type().into(),
&[left.get_type(), right.get_type()],
)
.broadcast_starmap(
generator,
ctx,
&[left, right],
NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() },
|generator, ctx, scalars| {
let left_scalar = scalars[0];
let right_scalar = scalars[1];
let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty) let val = gen_cmpop_expr_with_values(
.map_value(lhs.into_pointer_value(), None); generator,
let res = numpy::ndarray_elementwise_binop_impl( ctx,
generator, (Some(left_ty_dtype), left_scalar),
ctx, &[op],
ctx.primitives.bool, &[(Some(right_ty_dtype), right_scalar)],
None, )?
(left_ty, left_val.as_base_value().into(), false), .unwrap()
(right_ty, rhs, false), .to_basic_value_enum(
|generator, ctx, (lhs, rhs)| { ctx,
let val = gen_cmpop_expr_with_values( generator,
generator, ctx.primitives.bool,
ctx, )?;
(Some(ndarray_dtype1), lhs),
&[op],
&[(Some(ndarray_dtype2), rhs)],
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ctx.primitives.bool,
)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
}, },
)?; )?;
Ok(Some(res.as_base_value().into())) return Ok(Some(result_ndarray.as_base_value().into()));
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty },
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(left_ty, lhs, !is_ndarray1),
(right_ty, rhs, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype), lhs),
&[op],
&[(Some(ndarray_dtype), rhs)],
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ctx.primitives.bool,
)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_base_value().into()))
};
} }
} }

View File

@ -1,18 +1,15 @@
use inkwell::{ use inkwell::{
types::BasicTypeEnum, types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue}, values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate, AddressSpace,
}; };
use itertools::Either; use itertools::Either;
use super::get_usize_dependent_function_name; use super::get_usize_dependent_function_name;
use crate::codegen::{ use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{ values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, TypedArrayLikeAdapter,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|_, _, v| v.into(), |_, _, v| v.into(),
) )
} }
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast");
let ndarray_calc_broadcast_fn =
ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into())
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|_, _, v| v.into_int_value(),
|_, _, v| v.into(),
)
}

View File

@ -10,10 +10,7 @@ use super::{
expr::gen_binop_expr_with_values, expr::gen_binop_expr_with_values,
irrt::{ irrt::{
calculate_len_for_slice_range, calculate_len_for_slice_range,
ndarray::{ ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size},
call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices, call_ndarray_calc_size,
},
}, },
llvm_intrinsics::{self, call_memcpy_generic}, llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable, macros::codegen_unreachable,
@ -21,7 +18,7 @@ use super::{
types::ndarray::{factory::ndarray_zero_value, NDArrayType}, types::ndarray::{factory::ndarray_zero_value, NDArrayType},
values::{ values::{
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator, UntypedArrayLikeMutator,
}, },
@ -195,152 +192,6 @@ where
}) })
} }
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
/// the target `ndarray`.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target: NDArrayValue<'ctx>,
source: NDArrayValue<'ctx>,
) {
let array_ndims = source.load_ndims(ctx);
let broadcast_size = target.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
res: NDArrayValue<'ctx>,
(lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
(rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String>,
{
assert!(
!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type()
);
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
// `indices` where necessary.
//
// Required for compatibility with `NDArrayType::get_unchecked`.
let get_data_by_indices_compat =
|generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
let llvm_usize = generator.get_size_type(ctx.ctx);
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
let indices = if llvm_usize == ctx.ctx.i32_type() {
indices
} else {
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
ArraySliceValue::from_ptr_val(
ctx.builder
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
.unwrap(),
indices.size(ctx, generator),
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(indices.size(ctx, generator), false),
|generator, ctx, _, i| {
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
let idx = ctx
.builder
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
.unwrap();
unsafe {
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
}
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
indices_usize
};
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
llvm_intrinsics::call_stackrestore(ctx, stackptr);
elem
};
// Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar {
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
}
if !rhs_scalar {
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
}
ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| {
let lhs_elem = if lhs_scalar {
lhs_val
} else {
let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
};
let rhs_elem = if rhs_scalar {
rhs_val
} else {
let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
};
value_fn(generator, ctx, (lhs_elem, rhs_elem))
})?;
Ok(res)
}
/// Copies a slice of an [`NDArrayValue`] to another. /// Copies a slice of an [`NDArrayValue`] to another.
/// ///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape`
@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
} }
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
/// `value_fn` arguments tuple for all output elements.
///
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
/// broadcast to all elements).
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
///
/// # Panic
///
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: (Type, BasicValueEnum<'ctx>, bool),
rhs: (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
(BasicValueEnum<'ctx>, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String>,
{
let (lhs_ty, lhs_val, lhs_scalar) = lhs;
let (rhs_ty, rhs_val, rhs_scalar) = rhs;
assert!(
!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type()
);
let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar {
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
.map_value(lhs_val.into_pointer_value(), None);
let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty)
.map_value(rhs_val.into_pointer_value(), None);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
|generator, ctx, v| Ok(v.size(ctx, generator)),
|generator, ctx, v, idx| unsafe {
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
} else {
let ndarray = NDArrayType::from_unifier_type(
generator,
ctx,
if lhs_scalar { rhs_ty } else { lhs_ty },
)
.map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray,
|_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe {
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
},
)
.unwrap()
}
});
ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| {
value_fn(generator, ctx, elems)
})?;
Ok(ndarray)
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.