[core] codegen/ndarray: Use IRRT for size() and indexing operations

Also refactor some usages of call_ndarray_calc_size with ndarray.size().
This commit is contained in:
David Mak 2024-12-19 12:21:08 +08:00
parent 155002629b
commit 90071be0a7
5 changed files with 102 additions and 192 deletions

View File

@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
} }
} }
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT> template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims, SizeT lhs_ndims,
@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims, uint32_t lhs_ndims,
const uint32_t* rhs_dims, const uint32_t* rhs_dims,

View File

@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty);
let n = llvm_ndarray_ty.map_value(n, None); let n = llvm_ndarray_ty.map_value(n, None);
let n_sz = let n_sz = n.size(generator, ctx);
irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx let n_sz_eqz = ctx
.builder .builder
@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
llvm_int64.const_int(1, false), llvm_int64.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe {
n.data().get_unchecked(
ctx,
generator,
&ctx.builder
.build_int_truncate_or_bit_cast(idx, llvm_usize, "")
.unwrap(),
None,
)
};
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();

View File

@ -1,5 +1,5 @@
use inkwell::{ use inkwell::{
types::{BasicTypeEnum, IntType}, types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue}, values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
) )
} }
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// the multidimensional index. /// dimension and size of each dimension of the resultant `ndarray`.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index");
let ndarray_flatten_index_fn =
ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`]
/// containing the size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -21,8 +21,8 @@ use super::{
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
types::{ndarray::NDArrayType, ListType, ProxyType}, types::{ndarray::NDArrayType, ListType, ProxyType},
values::{ values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}, },
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
@ -318,12 +318,7 @@ where
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = ndarray.size(generator, ctx);
generator,
ctx,
&ndarray.shape().as_slice_value(ctx, generator),
(None, None),
);
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
@ -434,6 +429,66 @@ where
rhs_val.get_type() rhs_val.get_type()
); );
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
// `indices` where necessary.
//
// Required for compatibility with `NDArrayType::get_unchecked`.
let get_data_by_indices_compat =
|generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
let llvm_usize = generator.get_size_type(ctx.ctx);
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
let indices = if llvm_usize == ctx.ctx.i32_type() {
indices
} else {
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
ArraySliceValue::from_ptr_val(
ctx.builder
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
.unwrap(),
indices.size(ctx, generator),
None,
),
|_, _, val| val.into_int_value(),
|_, _, val| val.into(),
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(indices.size(ctx, generator), false),
|generator, ctx, _, i| {
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
let idx = ctx
.builder
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
.unwrap();
unsafe {
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
}
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
indices_usize
};
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
llvm_intrinsics::call_stackrestore(ctx, stackptr);
elem
};
// Assert that all ndarray operands are broadcastable to the target size // Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar { if !lhs_scalar {
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty) let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
@ -455,7 +510,7 @@ where
.map_value(lhs_val.into_pointer_value(), None); .map_value(lhs_val.into_pointer_value(), None);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
}; };
let rhs_elem = if rhs_scalar { let rhs_elem = if rhs_scalar {
@ -465,7 +520,7 @@ where
.map_value(rhs_val.into_pointer_value(), None); .map_value(rhs_val.into_pointer_value(), None);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
}; };
value_fn(generator, ctx, (lhs_elem, rhs_elem)) value_fn(generator, ctx, (lhs_elem, rhs_elem))
@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
lhs: NDArrayValue<'ctx>, lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
}; };
let idx0 = unsafe { let idx0 = unsafe {
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
}; };
let idx1 = unsafe { let idx1 = unsafe {
let idx1 = let idx1 =
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
}; };
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
None, None,
llvm_i32.const_zero(), llvm_usize.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, _, i| { |generator, ctx, _, i| {
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc( let ab_idx = generator.gen_array_var_alloc(
ctx, ctx,
llvm_i32.into(), llvm_usize.into(),
llvm_usize.const_int(2, false), llvm_usize.const_int(2, false),
None, None,
)?; )?;
@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None); let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = n1.size(generator, ctx);
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape( let out = create_ndarray_dyn_shape(
@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
let n1 = llvm_ndarray_ty.map_value(n1, None); let n1 = llvm_ndarray_ty.map_value(n1, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n_sz = n1.size(generator, ctx);
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
); );
// The new shape must be compatible with the old shape // The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); let out_sz = out.size(generator, ctx);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n1_sz = n1.size(generator, ctx);
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let n2_sz = n2.size(generator, ctx);
ctx.make_assert( ctx.make_assert(
generator, generator,

View File

@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
generator: &G, generator: &G,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
irrt::ndarray::call_ndarray_calc_size( irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0)
generator,
ctx,
&self.as_slice_value(ctx, generator),
(None, None),
)
} }
} }
@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let sizeof_elem = ctx let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx);
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
idx.get_type(),
"",
)
.unwrap();
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[idx],
name.unwrap_or_default(),
)
.unwrap()
};
// Current implementation is transparent - The returned pointer type is // Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately // already cast into the expected type, allowing for immediately
@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
BasicTypeEnum::try_from(self.element_type(ctx, generator)) BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap() .unwrap()
.ptr_type(AddressSpace::default()), .ptr_type(AddressSpace::default()),
"", name.unwrap_or_default(),
) )
.unwrap() .unwrap()
} }
@ -769,52 +747,25 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into());
let indices_elem_ty = unsafe { let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices(
indices generator,
.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) ctx,
.get_type() *self.0,
.get_element_type() indices.base_ptr(ctx, generator),
};
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
); );
let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices); // Current implementation is transparent - The returned pointer type is
let sizeof_elem = ctx // already cast into the expected type, allowing for immediately
.builder // load/store.
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
index.get_type(),
"",
)
.unwrap();
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
};
// TODO: Current implementation is transparent
ctx.builder ctx.builder
.build_pointer_cast( .build_pointer_cast(
ptr, ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator)) BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap() .unwrap()
.ptr_type(AddressSpace::default()), .ptr_type(AddressSpace::default()),
"", name.unwrap_or_default(),
) )
.unwrap() .unwrap()
} }