[core] codegen/ndarray: Use IRRT for size() and indexing operations
Also replace some usages of call_ndarray_calc_size with ndarray.size().
This commit is contained in:
parent
155002629b
commit
90071be0a7
@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
|
||||
SizeT num_dims,
|
||||
const NDIndexInt* indices,
|
||||
SizeT num_indices) {
|
||||
SizeT idx = 0;
|
||||
SizeT stride = 1;
|
||||
for (SizeT i = 0; i < num_dims; ++i) {
|
||||
SizeT ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += stride * indices[ri];
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||
SizeT lhs_ndims,
|
||||
@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
uint32_t
|
||||
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
const NDIndexInt* indices,
|
||||
uint64_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t* rhs_dims,
|
||||
|
@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty);
|
||||
|
||||
let n = llvm_ndarray_ty.map_value(n, None);
|
||||
let n_sz =
|
||||
irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
||||
let n_sz = n.size(generator, ctx);
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx
|
||||
.builder
|
||||
@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
llvm_int64.const_int(1, false),
|
||||
(n_sz, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
let elem = unsafe {
|
||||
n.data().get_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&ctx.builder
|
||||
.build_int_truncate_or_bit_cast(idx, llvm_usize, "")
|
||||
.unwrap(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use inkwell::{
|
||||
types::{BasicTypeEnum, IntType},
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for
|
||||
/// the multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &Index,
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Index: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
debug_assert_eq!(
|
||||
IntType::try_from(indices.element_type(ctx, generator))
|
||||
.map(IntType::get_bit_width)
|
||||
.unwrap_or_default(),
|
||||
llvm_i32.get_bit_width(),
|
||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
debug_assert_eq!(
|
||||
indices.size(ctx, generator).get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
|
||||
let ndarray_flatten_index_fn_name =
|
||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index");
|
||||
let ndarray_flatten_index_fn =
|
||||
ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.shape();
|
||||
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.base_ptr(ctx, generator).into(),
|
||||
indices.size(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
index
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`]
|
||||
/// containing the size of each dimension of the resultant `ndarray`.
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -21,8 +21,8 @@ use super::{
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
|
||||
ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
@ -318,12 +318,7 @@ where
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&ndarray.shape().as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
);
|
||||
let ndarray_num_elems = ndarray.size(generator, ctx);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
@ -434,6 +429,66 @@ where
|
||||
rhs_val.get_type()
|
||||
);
|
||||
|
||||
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
|
||||
// `indices` where necessary.
|
||||
//
|
||||
// Required for compatibility with `NDArrayType::get_unchecked`.
|
||||
let get_data_by_indices_compat =
|
||||
|generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
|
||||
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
|
||||
let indices = if llvm_usize == ctx.ctx.i32_type() {
|
||||
indices
|
||||
} else {
|
||||
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder
|
||||
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
|
||||
.unwrap(),
|
||||
indices.size(ctx, generator),
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(indices.size(ctx, generator), false),
|
||||
|generator, ctx, _, i| {
|
||||
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
|
||||
let idx = ctx
|
||||
.builder
|
||||
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
|
||||
.unwrap();
|
||||
unsafe {
|
||||
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
indices_usize
|
||||
};
|
||||
|
||||
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
llvm_intrinsics::call_stackrestore(ctx, stackptr);
|
||||
|
||||
elem
|
||||
};
|
||||
|
||||
// Assert that all ndarray operands are broadcastable to the target size
|
||||
if !lhs_scalar {
|
||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
||||
@ -455,7 +510,7 @@ where
|
||||
.map_value(lhs_val.into_pointer_value(), None);
|
||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
||||
|
||||
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
|
||||
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
|
||||
};
|
||||
|
||||
let rhs_elem = if rhs_scalar {
|
||||
@ -465,7 +520,7 @@ where
|
||||
.map_value(rhs_val.into_pointer_value(), None);
|
||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
||||
|
||||
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
|
||||
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
|
||||
};
|
||||
|
||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
||||
@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
|
||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||
|
||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
|
||||
};
|
||||
|
||||
let idx0 = unsafe {
|
||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||
|
||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
|
||||
};
|
||||
let idx1 = unsafe {
|
||||
let idx1 =
|
||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
||||
|
||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
|
||||
};
|
||||
|
||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||
@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_i32.const_zero(),
|
||||
llvm_usize.const_zero(),
|
||||
(common_dim, false),
|
||||
|generator, ctx, _, i| {
|
||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
||||
|
||||
let ab_idx = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_i32.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_usize.const_int(2, false),
|
||||
None,
|
||||
)?;
|
||||
@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
let n_sz = n1.size(generator, ctx);
|
||||
|
||||
// Dimensions are reversed in the transposed array
|
||||
let out = create_ndarray_dyn_shape(
|
||||
@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
let n_sz = n1.size(generator, ctx);
|
||||
|
||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
);
|
||||
|
||||
// The new shape must be compatible with the old shape
|
||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
|
||||
let out_sz = out.size(generator, ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||
@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
|
||||
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
||||
|
||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
||||
let n1_sz = n1.size(generator, ctx);
|
||||
let n2_sz = n2.size(generator, ctx);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
|
@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
irrt::ndarray::call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&self.as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
)
|
||||
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let sizeof_elem = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
self.element_type(ctx, generator).size_of().unwrap(),
|
||||
idx.get_type(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
|
||||
let ptr = unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[idx],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx);
|
||||
|
||||
// Current implementation is transparent - The returned pointer type is
|
||||
// already cast into the expected type, allowing for immediately
|
||||
@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||
.unwrap()
|
||||
.ptr_type(AddressSpace::default()),
|
||||
"",
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
@ -769,52 +747,25 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||
indices: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into());
|
||||
|
||||
let indices_elem_ty = unsafe {
|
||||
indices
|
||||
.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.get_type()
|
||||
.get_element_type()
|
||||
};
|
||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||
};
|
||||
assert_eq!(
|
||||
indices_elem_ty.get_bit_width(),
|
||||
32,
|
||||
"Expected list[int32] but got list[int{}]",
|
||||
indices_elem_ty.get_bit_width()
|
||||
let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices(
|
||||
generator,
|
||||
ctx,
|
||||
*self.0,
|
||||
indices.base_ptr(ctx, generator),
|
||||
);
|
||||
|
||||
let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
||||
let sizeof_elem = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
self.element_type(ctx, generator).size_of().unwrap(),
|
||||
index.get_type(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
|
||||
|
||||
let ptr = unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[index],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
// TODO: Current implementation is transparent
|
||||
// Current implementation is transparent - The returned pointer type is
|
||||
// already cast into the expected type, allowing for immediately
|
||||
// load/store.
|
||||
ctx.builder
|
||||
.build_pointer_cast(
|
||||
ptr,
|
||||
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||
.unwrap()
|
||||
.ptr_type(AddressSpace::default()),
|
||||
"",
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user