forked from M-Labs/nac3
[core] codegen/ndarray: Cleanup
- Remove redundant size param - Add *_field functions
This commit is contained in:
parent
363e1a1f84
commit
355c051886
@ -110,12 +110,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
|
||||
// TODO: Move this into e.g. StructProxyType
|
||||
#[must_use]
|
||||
pub fn get_fields(
|
||||
&self,
|
||||
ctx: &'ctx Context,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
Self::fields(ctx, llvm_usize)
|
||||
pub fn get_fields(&self, ctx: &'ctx Context) -> NDArrayStructFields<'ctx> {
|
||||
Self::fields(ctx, self.llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||
|
@ -12,7 +12,7 @@ use crate::codegen::{
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||
llvm_intrinsics::call_int_umin,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::NDArrayType,
|
||||
types::{structure::StructField, NDArrayType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
@ -48,12 +48,13 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields(ctx.ctx).ndims
|
||||
}
|
||||
|
||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.ndims
|
||||
.ptr_by_gep(ctx, self.value, self.name)
|
||||
self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions `ndims` into this instance.
|
||||
@ -75,18 +76,19 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||
}
|
||||
|
||||
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields(ctx.ctx).shape
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||
/// `getelementptr` on the field.
|
||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.shape
|
||||
.ptr_by_gep(ctx, self.value, self.name)
|
||||
self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the array of dimension sizes `dims` into this instance.
|
||||
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||
self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||
@ -105,13 +107,14 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
NDArrayShapeProxy(self)
|
||||
}
|
||||
|
||||
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields(ctx.ctx).data
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.data
|
||||
.ptr_by_gep(ctx, self.value, self.name)
|
||||
self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the array of data elements `data` into this instance.
|
||||
@ -120,7 +123,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
.builder
|
||||
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||
.unwrap();
|
||||
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
||||
self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing data elements with the given element
|
||||
@ -188,12 +191,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder
|
||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
@ -285,12 +283,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder
|
||||
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
|
Loading…
Reference in New Issue
Block a user