From e3fe3f03fbc4f1369c779ff1976f8e0801001423 Mon Sep 17 00:00:00 2001
From: David Mak
Date: Wed, 13 Mar 2024 11:08:43 +0800
Subject: [PATCH] core: Implement calculations for broadcasting ndarrays

---
 nac3core/src/codegen/irrt/irrt.c | 66 ++++++++++++++++++++++++++++++++
 nac3core/src/codegen/irrt/mod.rs | 66 ++++++++++++++++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c
index bbe27ce..363c3c2 100644
--- a/nac3core/src/codegen/irrt/irrt.c
+++ b/nac3core/src/codegen/irrt/irrt.c
@@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t;
 # define MAX(a, b) (a > b ? a : b)
 # define MIN(a, b) (a > b ? b : a)
 
+# define NULL ((void *) 0)
+
 // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
 // need to make sure `exp >= 0` before calling this function
 #define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
@@ -293,3 +295,67 @@ uint64_t __nac3_ndarray_flatten_index64(
     }
     return idx;
 }
+
+// Computes the broadcast shape of two ndarrays, aligning dimensions starting from the last axis.
+// `out_dims` must have space for `max(lhs_ndims, rhs_ndims)` elements. The shapes must be
+// broadcast-compatible; mismatched dimension sizes (neither equal nor 1) are undefined behavior.
+void __nac3_ndarray_calc_broadcast(
+    const uint32_t *lhs_dims,
+    uint32_t lhs_ndims,
+    const uint32_t *rhs_dims,
+    uint32_t rhs_ndims,
+    uint32_t *out_dims
+) {
+    uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
+
+    for (uint32_t i = 0; i < max_ndims; ++i) {
+        const uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
+        const uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
+        uint32_t *out_dim = &out_dims[max_ndims - i - 1];
+
+        if (lhs_dim_sz == NULL) {
+            *out_dim = *rhs_dim_sz;
+        } else if (rhs_dim_sz == NULL) {
+            *out_dim = *lhs_dim_sz;
+        } else if (*lhs_dim_sz == 1) {
+            *out_dim = *rhs_dim_sz;
+        } else if (*rhs_dim_sz == 1) {
+            *out_dim = *lhs_dim_sz;
+        } else if (*lhs_dim_sz == *rhs_dim_sz) {
+            *out_dim = *lhs_dim_sz;
+        } else {
+            __builtin_unreachable();
+        }
+    }
+}
+
+// 64-bit variant of `__nac3_ndarray_calc_broadcast`; the same preconditions apply.
+void __nac3_ndarray_calc_broadcast64(
+    const uint64_t *lhs_dims,
+    uint64_t lhs_ndims,
+    const uint64_t *rhs_dims,
+    uint64_t rhs_ndims,
+    uint64_t *out_dims
+) {
+    uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
+
+    for (uint64_t i = 0; i < max_ndims; ++i) {
+        const uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
+        const uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
+        uint64_t *out_dim = &out_dims[max_ndims - i - 1];
+
+        if (lhs_dim_sz == NULL) {
+            *out_dim = *rhs_dim_sz;
+        } else if (rhs_dim_sz == NULL) {
+            *out_dim = *lhs_dim_sz;
+        } else if (*lhs_dim_sz == 1) {
+            *out_dim = *rhs_dim_sz;
+        } else if (*rhs_dim_sz == 1) {
+            *out_dim = *lhs_dim_sz;
+        } else if (*lhs_dim_sz == *rhs_dim_sz) {
+            *out_dim = *lhs_dim_sz;
+        } else {
+            __builtin_unreachable();
+        }
+    }
+}
diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs
index 0008d7b..d2c9248 100644
--- a/nac3core/src/codegen/irrt/mod.rs
+++ b/nac3core/src/codegen/irrt/mod.rs
@@ -4,6 +4,7 @@ use super::{
     classes::{ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, UntypedArrayLikeMutator},
     CodeGenContext,
     CodeGenerator,
+    llvm_intrinsics,
 };
 use inkwell::{
     attributes::{Attribute, AttributeLoc},
@@ -813,4 +814,69 @@ pub fn call_ndarray_flatten_index_const<'ctx, G: CodeGenerator + ?Sized>(
         ndarray,
         &indices_alloca,
     )
+}
+
+/// Generates a call to `__nac3_ndarray_calc_broadcast`.
+///
+/// Returns a tuple containing the number of dimensions and a pointer to the size of each
+/// dimension of the resultant `ndarray`.
+///
+/// The IRRT function treats incompatible dimension sizes as unreachable and no compatibility
+/// checks are generated here yet, so callers must only pass broadcastable operands.
+pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
+    generator: &mut G,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    lhs: NDArrayValue<'ctx>,
+    rhs: NDArrayValue<'ctx>,
+) -> (IntValue<'ctx>, PointerValue<'ctx>) {
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
+
+    let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
+        32 => "__nac3_ndarray_calc_broadcast",
+        64 => "__nac3_ndarray_calc_broadcast64",
+        bw => unreachable!("Unsupported size type bit width: {}", bw),
+    };
+    let ndarray_calc_broadcast_fn =
+        ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
+            // The IRRT function is declared `void` in irrt.c; results are written to `out_dims`.
+            let fn_type = ctx.ctx.void_type().fn_type(
+                &[
+                    llvm_pusize.into(),
+                    llvm_usize.into(),
+                    llvm_pusize.into(),
+                    llvm_usize.into(),
+                    llvm_pusize.into(),
+                ],
+                false,
+            );
+
+            ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
+        });
+
+    let lhs_ndims = lhs.load_ndims(ctx);
+    let rhs_ndims = rhs.load_ndims(ctx);
+    let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
+
+    // TODO: Generate assertion checks for whether each dimension is compatible
+
+    let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
+    let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
+    let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
+
+    ctx.builder
+        .build_call(
+            ndarray_calc_broadcast_fn,
+            &[
+                lhs_dims.into(),
+                lhs_ndims.into(),
+                rhs_dims.into(),
+                rhs_ndims.into(),
+                out_dims.into(),
+            ],
+            "",
+        )
+        .unwrap();
+
+    (max_ndims, out_dims)
 }
\ No newline at end of file