diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 7fc9a63b9..534f18d68 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -28,46 +28,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n stride *= dims[i]; } } - -template -void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, - SizeT src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} } // namespace extern "C" { @@ -87,34 +47,4 @@ void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32 void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); } - -void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} } \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 232c68c2b..8a002bb37 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1852,83 +1852,52 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, left_ty); + let right_ty_dtype = arraylike_flatten_element_type(&mut ctx.unifier, right_ty); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let left = ScalarOrNDArray::from_value(generator, ctx, (left_ty, left)) + .to_ndarray(generator, ctx); + let right = ScalarOrNDArray::from_value(generator, ctx, (right_ty, right)) + .to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayType::new_broadcast( + generator, + ctx.ctx, + ctx.ctx.i8_type().into(), + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = NDArrayType::from_unifier_type(generator, ctx, left_ty) - .map_value(lhs.into_pointer_value(), None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, left_val.as_base_value().into(), false), - (right_ty, rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty_dtype), left_scalar), + &[op], + &[(Some(right_ty_dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_ty, lhs, !is_ndarray1), - (right_ty, rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.as_base_value().into())); } } diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index ba22568e6..151795c5f 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,18 +1,15 @@ use inkwell::{ types::BasicTypeEnum, values::{BasicValueEnum, CallSiteValue, IntValue}, - AddressSpace, IntPredicate, + AddressSpace, }; use itertools::Either; use super::get_usize_dependent_function_name; use crate::codegen::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, values::{ ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + TypedArrayLikeAdapter, }, CodeGenContext, CodeGenerator, }; @@ -145,166 +142,3 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( |_, _, v| v.into(), ) } - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast"); - let ndarray_calc_broadcast_fn = - ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.shape().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.shape().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from(out_dims, |_, _, v| v.into_int_value(), |_, _, v| v.into()) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.shape().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - |_, _, v| v.into_int_value(), - |_, _, v| v.into(), - ) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d02103ab2..9fe5a9722 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -10,10 +10,7 @@ use super::{ expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, - ndarray::{ - call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, - call_ndarray_calc_nd_indices, call_ndarray_calc_size, - }, + ndarray::{call_ndarray_calc_nd_indices, call_ndarray_calc_size}, }, llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, @@ -21,7 +18,7 @@ use super::{ types::ndarray::{factory::ndarray_zero_value, NDArrayType}, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, @@ -195,152 +192,6 @@ where }) } -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. -fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - (lhs_ty, lhs_val, lhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - (rhs_ty, rhs_val, rhs_scalar): (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Returns the element of an ndarray indexed by the given indices, performing int-promotion on - // `indices` where necessary. - // - // Required for compatibility with `NDArrayType::get_unchecked`. - let get_data_by_indices_compat = - |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| { - let llvm_usize = generator.get_size_type(ctx.ctx); - - // Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT - let stackptr = llvm_intrinsics::call_stacksave(ctx, None); - let indices = if llvm_usize == ctx.ctx.i32_type() { - indices - } else { - let indices_usize = TypedArrayLikeAdapter::>::from( - ArraySliceValue::from_ptr_val( - ctx.builder - .build_array_alloca(llvm_usize, indices.size(ctx, generator), "") - .unwrap(), - indices.size(ctx, generator), - None, - ), - |_, _, val| val.into_int_value(), - |_, _, val| val.into(), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (indices.size(ctx, generator), false), - |generator, ctx, _, i| { - let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) }; - let idx = ctx - .builder - .build_int_z_extend_or_bit_cast(idx, llvm_usize, "") - .unwrap(); - unsafe { - indices_usize.set_typed_unchecked(ctx, generator, &i, idx); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - indices_usize - }; - - let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) }; - - llvm_intrinsics::call_stackrestore(ctx, stackptr); - - elem - }; - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - get_data_by_indices_compat(generator, ctx, lhs, lhs_idx) - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - get_data_by_indices_compat(generator, ctx, rhs, rhs_idx) - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - /// Copies a slice of an [`NDArrayValue`] to another. /// /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `shape` @@ -592,101 +443,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (Type, BasicValueEnum<'ctx>, bool), - rhs: (Type, BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let (lhs_ty, lhs_val, lhs_scalar) = lhs; - let (rhs_ty, rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) - .map_value(lhs_val.into_pointer_value(), None); - let rhs_val = NDArrayType::from_unifier_type(generator, ctx, rhs_ty) - .map_value(rhs_val.into_pointer_value(), None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayType::from_unifier_type( - generator, - ctx, - if lhs_scalar { rhs_ty } else { lhs_ty }, - ) - .map_value(if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), None); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// /// * `elem_ty` - The element type of the `NDArray`.