core: Implement calculations for broadcasting ndarrays
This commit is contained in:
parent
deb325de4f
commit
e3fe3f03fb
|
@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t;
|
|||
# define MAX(a, b) (a > b ? a : b)
|
||||
# define MIN(a, b) (a > b ? b : a)
|
||||
|
||||
# define NULL ((void *) 0)
|
||||
|
||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
// need to make sure `exp >= 0` before calling this function
|
||||
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
|
||||
|
@ -293,3 +295,63 @@ uint64_t __nac3_ndarray_flatten_index64(
|
|||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(
|
||||
const uint32_t *lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t *rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t *out_dims
|
||||
) {
|
||||
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (uint32_t i = 0; i < max_ndims; ++i) {
|
||||
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
|
||||
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
|
||||
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == NULL) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == NULL) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(
|
||||
const uint64_t *lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t *rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t *out_dims
|
||||
) {
|
||||
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (uint64_t i = 0; i < max_ndims; ++i) {
|
||||
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
|
||||
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
|
||||
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == NULL) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == NULL) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use super::{
|
|||
classes::{ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, UntypedArrayLikeMutator},
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
llvm_intrinsics,
|
||||
};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
|
@ -813,4 +814,79 @@ pub fn call_ndarray_flatten_index_const<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ndarray,
|
||||
&indices_alloca,
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> (IntValue<'ctx>, PointerValue<'ctx>) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast",
|
||||
64 => "__nac3_ndarray_calc_broadcast64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||
|
||||
// TODO: Generate assertion checks for whether each dimension is compatible
|
||||
// gen_for_callback_incrementing(
|
||||
// generator,
|
||||
// ctx,
|
||||
// llvm_usize.const_zero(),
|
||||
// (max_ndims, false),
|
||||
// |generator, ctx, idx| {
|
||||
// let lhs_dim_sz =
|
||||
//
|
||||
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
|
||||
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
|
||||
//
|
||||
//
|
||||
// },
|
||||
// llvm_usize.const_int(1, false),
|
||||
// ).unwrap();
|
||||
|
||||
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[
|
||||
lhs_dims.into(),
|
||||
lhs_ndims.into(),
|
||||
rhs_dims.into(),
|
||||
rhs_ndims.into(),
|
||||
out_dims.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
(max_ndims, out_dims)
|
||||
}
|
Loading…
Reference in New Issue