diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index fda9312..59c481f 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t; # define MAX(a, b) (a > b ? a : b) # define MIN(a, b) (a > b ? b : a) +# define NULL ((void *) 0) + // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function #define DEF_INT_EXP(T) T __nac3_int_exp_##T( \ @@ -293,3 +295,87 @@ uint64_t __nac3_ndarray_flatten_index64( } return idx; } + +void __nac3_ndarray_calc_broadcast( + const uint32_t *lhs_dims, + uint32_t lhs_ndims, + const uint32_t *rhs_dims, + uint32_t rhs_ndims, + uint32_t *out_dims +) { + uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; + + for (uint32_t i = 0; i < max_ndims; ++i) { + uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL; + uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL; + uint32_t *out_dim = &out_dims[max_ndims - i - 1]; + + if (lhs_dim_sz == NULL) { + *out_dim = *rhs_dim_sz; + } else if (rhs_dim_sz == NULL) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == 1) { + *out_dim = *rhs_dim_sz; + } else if (*rhs_dim_sz == 1) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == *rhs_dim_sz) { + *out_dim = *lhs_dim_sz; + } else { + __builtin_unreachable(); + } + } +} + +void __nac3_ndarray_calc_broadcast64( + const uint64_t *lhs_dims, + uint64_t lhs_ndims, + const uint64_t *rhs_dims, + uint64_t rhs_ndims, + uint64_t *out_dims +) { + uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; + + for (uint64_t i = 0; i < max_ndims; ++i) { + uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL; + uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL; + uint64_t *out_dim = &out_dims[max_ndims - i - 1]; + + if (lhs_dim_sz == NULL) { + *out_dim = *rhs_dim_sz; + } else if (rhs_dim_sz == NULL) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == 1) { + *out_dim = *rhs_dim_sz; + } else if (*rhs_dim_sz == 1) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == *rhs_dim_sz) { + *out_dim = *lhs_dim_sz; + } else { + __builtin_unreachable(); + } + } +} + +void __nac3_ndarray_calc_broadcast_idx( + const uint32_t *src_dims, + uint32_t src_ndims, + const uint32_t *in_idx, + uint32_t *out_idx +) { + for (uint32_t i = 0; i < src_ndims; ++i) { + uint32_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; + } +} + +void __nac3_ndarray_calc_broadcast_idx64( + const uint64_t *src_dims, + uint64_t src_ndims, + const uint32_t *in_idx, + uint32_t *out_idx +) { + for (uint64_t i = 0; i < src_ndims; ++i) { + uint64_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i]; + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 6da048f..97592f4 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -8,9 +8,11 @@ use super::{ ListValue, NDArrayValue, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, + llvm_intrinsics, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -783,4 +785,153 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>( ndarray, indices, ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. +pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast", + 64 => "__nac3_ndarray_calc_broadcast64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_ndims = rhs.load_ndims(ctx); + let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); + + // TODO: Generate assertion checks for whether each dimension is compatible + // gen_for_callback_incrementing( + // generator, + // ctx, + // llvm_usize.const_zero(), + // (max_ndims, false), + // |generator, ctx, idx| { + // let lhs_dim_sz = + // + // let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None); + // let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None); + // + // + // }, + // llvm_usize.const_int(1, false), + // ).unwrap(); + + let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_ndims = rhs.load_ndims(ctx); + let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); + let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + lhs_dims.into(), + lhs_ndims.into(), + rhs_dims.into(), + rhs_ndims.into(), + out_dims.base_ptr(ctx, generator).into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + out_dims, + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] +/// containing the indices used for accessing `array` corresponding to the index of the broadcasted +/// array `broadcast_idx`. +pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + array: NDArrayValue<'ctx>, + broadcast_idx: &BroadcastIdx, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast_idx", + 64 => "__nac3_ndarray_calc_broadcast_idx64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pi32.into(), + llvm_pi32.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + // TODO: Assertions + + let broadcast_size = broadcast_idx.size(ctx, generator); + let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); + + let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_ndims = array.load_ndims(ctx); + let broadcast_idx_ptr = unsafe { + broadcast_idx.ptr_offset_unchecked( + ctx, + generator, + llvm_usize.const_zero(), + None + ) + }; + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + array_dims.into(), + array_ndims.into(), + broadcast_idx_ptr.into(), + out_idx.into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) } \ No newline at end of file