From 5a08736b51d3a44df886877c24a8e118c9bfe305 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 28 Nov 2024 11:07:14 +0800 Subject: [PATCH] [core] codegen/ndarray: Cleanup - Remove redundant size param - Add *_field functions --- nac3core/src/codegen/types/ndarray/mod.rs | 8 +--- nac3core/src/codegen/values/ndarray/mod.rs | 47 +++++++++------------- 2 files changed, 22 insertions(+), 33 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 2a37ae51..e11488a8 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -110,12 +110,8 @@ impl<'ctx> NDArrayType<'ctx> { // TODO: Move this into e.g. StructProxyType #[must_use] - pub fn get_fields( - &self, - ctx: &'ctx Context, - llvm_usize: IntType<'ctx>, - ) -> NDArrayStructFields<'ctx> { - Self::fields(ctx, llvm_usize) + pub fn get_fields(&self, ctx: &'ctx Context) -> NDArrayStructFields<'ctx> { + Self::fields(ctx, self.llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 38c6a980..458d771e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -12,7 +12,7 @@ use crate::codegen::{ irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, - types::NDArrayType, + types::{structure::StructField, NDArrayType}, CodeGenContext, CodeGenerator, }; @@ -48,12 +48,13 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayValue { value: ptr, dtype, llvm_usize, name } } + fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(ctx.ctx).ndims + } + /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.get_type() - .get_fields(ctx.ctx, self.llvm_usize) - .ndims - .ptr_by_gep(ctx, self.value, self.name) + self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the number of dimensions `ndims` into this instance. @@ -75,18 +76,19 @@ impl<'ctx> NDArrayValue<'ctx> { ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() } + fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(ctx.ctx).shape + } + /// Returns the double-indirection pointer to the `shape` array, as if by calling /// `getelementptr` on the field. fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.get_type() - .get_fields(ctx.ctx, self.llvm_usize) - .shape - .ptr_by_gep(ctx, self.value, self.name) + self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); + self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -105,13 +107,14 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayShapeProxy(self) } + fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(ctx.ctx).data + } + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.get_type() - .get_fields(ctx.ctx, self.llvm_usize) - .data - .ptr_by_gep(ctx, self.value, self.name) + self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. @@ -120,7 +123,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); + self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -186,12 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() + self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) } fn size( @@ -283,12 +281,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() + self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) } fn size(