From 87bc34f7ec8ac22e2e3aff5ee84af654835df046 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Mar 2024 11:08:43 +0800 Subject: [PATCH] core: Implement calculations for broadcasting ndarrays --- nac3core/src/codegen/irrt/irrt.c | 86 ++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 195 +++++++++++++++++++++++++++++++ 2 files changed, 281 insertions(+) diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index fda931215..59c481f56 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t; # define MAX(a, b) (a > b ? a : b) # define MIN(a, b) (a > b ? b : a) +# define NULL ((void *) 0) + // adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // need to make sure `exp >= 0` before calling this function #define DEF_INT_EXP(T) T __nac3_int_exp_##T( \ @@ -293,3 +295,87 @@ uint64_t __nac3_ndarray_flatten_index64( } return idx; } + +void __nac3_ndarray_calc_broadcast( + const uint32_t *lhs_dims, + uint32_t lhs_ndims, + const uint32_t *rhs_dims, + uint32_t rhs_ndims, + uint32_t *out_dims +) { + uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; + + for (uint32_t i = 0; i < max_ndims; ++i) { + uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL; + uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL; + uint32_t *out_dim = &out_dims[max_ndims - i - 1]; + + if (lhs_dim_sz == NULL) { + *out_dim = *rhs_dim_sz; + } else if (rhs_dim_sz == NULL) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == 1) { + *out_dim = *rhs_dim_sz; + } else if (*rhs_dim_sz == 1) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == *rhs_dim_sz) { + *out_dim = *lhs_dim_sz; + } else { + __builtin_unreachable(); + } + } +} + +void __nac3_ndarray_calc_broadcast64( + const uint64_t *lhs_dims, + uint64_t lhs_ndims, + const uint64_t *rhs_dims, + uint64_t rhs_ndims, + uint64_t *out_dims +) { + uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; + + for (uint64_t i = 0; i < max_ndims; ++i) { + uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL; + uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL; + uint64_t *out_dim = &out_dims[max_ndims - i - 1]; + + if (lhs_dim_sz == NULL) { + *out_dim = *rhs_dim_sz; + } else if (rhs_dim_sz == NULL) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == 1) { + *out_dim = *rhs_dim_sz; + } else if (*rhs_dim_sz == 1) { + *out_dim = *lhs_dim_sz; + } else if (*lhs_dim_sz == *rhs_dim_sz) { + *out_dim = *lhs_dim_sz; + } else { + __builtin_unreachable(); + } + } +} + +void __nac3_ndarray_calc_broadcast_idx( + const uint32_t *src_dims, + uint32_t src_ndims, + const uint32_t *in_idx, + uint32_t *out_idx +) { + for (uint32_t i = 0; i < src_ndims; ++i) { + uint32_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; + } +} + +void __nac3_ndarray_calc_broadcast_idx64( + const uint64_t *src_dims, + uint64_t src_ndims, + const uint32_t *in_idx, + uint32_t *out_idx +) { + for (uint64_t i = 0; i < src_ndims; ++i) { + uint64_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i]; + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 6da048fec..ce675ced6 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -8,9 +8,11 @@ use super::{ ListValue, NDArrayValue, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, + llvm_intrinsics, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -23,6 +25,8 @@ use inkwell::{ }; use itertools::Either; use nac3parser::ast::Expr; +use crate::codegen::classes::TypedArrayLikeAccessor; +use crate::codegen::stmt::gen_for_callback_incrementing; #[must_use] pub fn load_irrt(ctx: &Context) -> Module { @@ -783,4 +787,195 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>( ndarray, indices, ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. +pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast", + 64 => "__nac3_ndarray_calc_broadcast64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_ndims = rhs.load_ndims(ctx); + let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (min_ndims, false), + |generator, ctx, idx| { + let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); + let (lhs_dim_sz, rhs_dim_sz) = unsafe { + ( + lhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None), + rhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None), + ) + }; + + let llvm_usize_const_one = llvm_usize.const_int(1, false); + let lhs_eqz = ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_dim_sz, + llvm_usize_const_one, + "", + ).unwrap(); + let rhs_eqz = ctx.builder.build_int_compare( + IntPredicate::EQ, + rhs_dim_sz, + llvm_usize_const_one, + "", + ).unwrap(); + let lhs_or_rhs_eqz = ctx.builder.build_or( + lhs_eqz, + rhs_eqz, + "" + ).unwrap(); + + let lhs_eq_rhs = ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_dim_sz, + rhs_dim_sz, + "" + ).unwrap(); + + let is_compatible = ctx.builder.build_or( + lhs_or_rhs_eqz, + lhs_eq_rhs, + "" + ).unwrap(); + + ctx.make_assert( + generator, + is_compatible, + "0:ValueError", + "operands could not be broadcast together", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + ).unwrap(); + + let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); + let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_ndims = rhs.load_ndims(ctx); + let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); + let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + lhs_dims.into(), + lhs_ndims.into(), + rhs_dims.into(), + rhs_ndims.into(), + out_dims.base_ptr(ctx, generator).into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + out_dims, + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] +/// containing the indices used for accessing `array` corresponding to the index of the broadcasted +/// array `broadcast_idx`. +pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + array: NDArrayValue<'ctx>, + broadcast_idx: &BroadcastIdx, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast_idx", + 64 => "__nac3_ndarray_calc_broadcast_idx64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pi32.into(), + llvm_pi32.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let broadcast_size = broadcast_idx.size(ctx, generator); + let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); + + let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_ndims = array.load_ndims(ctx); + let broadcast_idx_ptr = unsafe { + broadcast_idx.ptr_offset_unchecked( + ctx, + generator, + llvm_usize.const_zero(), + None + ) + }; + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + array_dims.into(), + array_ndims.into(), + broadcast_idx_ptr.into(), + out_idx.into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) } \ No newline at end of file