[core] codegen/ndarray: Cleanup

- Remove redundant size param
- Add *_field functions
This commit is contained in:
David Mak 2024-11-28 11:07:14 +08:00
parent 6732111430
commit 5a08736b51
2 changed files with 22 additions and 33 deletions

View File

@ -110,12 +110,8 @@ impl<'ctx> NDArrayType<'ctx> {
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
pub fn get_fields(&self, ctx: &'ctx Context) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, self.llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.

View File

@ -12,7 +12,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
types::NDArrayType,
types::{structure::StructField, NDArrayType},
CodeGenContext, CodeGenerator,
};
@ -48,12 +48,13 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, llvm_usize, name }
}
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).ndims
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
@ -75,18 +76,19 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).shape
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
@ -105,13 +107,14 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayShapeProxy(self)
}
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).data
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of data elements `data` into this instance.
@ -120,7 +123,7 @@ impl<'ctx> NDArrayValue<'ctx> {
.builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
}
/// Convenience method for creating a new array storing data elements with the given element
@ -186,12 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
}
fn size<G: CodeGenerator + ?Sized>(
@ -283,12 +281,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
}
fn size<G: CodeGenerator + ?Sized>(