[core] codegen/ndarray: Cleanup

- Remove redundant size param
- Add *_field functions
This commit is contained in:
David Mak 2024-11-28 11:07:14 +08:00
parent 6732111430
commit 5a08736b51
2 changed files with 22 additions and 33 deletions

View File

@ -110,12 +110,8 @@ impl<'ctx> NDArrayType<'ctx> {
// TODO: Move this into e.g. StructProxyType // TODO: Move this into e.g. StructProxyType
#[must_use] #[must_use]
pub fn get_fields( pub fn get_fields(&self, ctx: &'ctx Context) -> NDArrayStructFields<'ctx> {
&self, Self::fields(ctx, self.llvm_usize)
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
} }
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.

View File

@ -12,7 +12,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin, llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::NDArrayType, types::{structure::StructField, NDArrayType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -48,12 +48,13 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, llvm_usize, name } NDArrayValue { value: ptr, dtype, llvm_usize, name }
} }
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).ndims
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the number of dimensions `ndims` into this instance. /// Stores the number of dimensions `ndims` into this instance.
@ -75,18 +76,19 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
} }
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).shape
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling /// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field. /// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the array of dimension sizes `dims` into this instance. /// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
} }
/// Convenience method for creating a new array storing dimension sizes with the given `size`. /// Convenience method for creating a new array storing dimension sizes with the given `size`.
@ -105,13 +107,14 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayShapeProxy(self) NDArrayShapeProxy(self)
} }
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).data
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
@ -120,7 +123,7 @@ impl<'ctx> NDArrayValue<'ctx> {
.builder .builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(); .unwrap();
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
} }
/// Convenience method for creating a new array storing data elements with the given element /// Convenience method for creating a new array storing data elements with the given element
@ -186,12 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
@ -283,12 +281,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
} }
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(