forked from M-Labs/nac3
[core] codegen/ndarray: Cleanup
- Remove redundant size param - Add *_field functions
This commit is contained in:
parent
363e1a1f84
commit
355c051886
@ -110,12 +110,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this into e.g. StructProxyType
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn get_fields(
|
pub fn get_fields(&self, ctx: &'ctx Context) -> NDArrayStructFields<'ctx> {
|
||||||
&self,
|
Self::fields(ctx, self.llvm_usize)
|
||||||
ctx: &'ctx Context,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
) -> NDArrayStructFields<'ctx> {
|
|
||||||
Self::fields(ctx, llvm_usize)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
|
@ -12,7 +12,7 @@ use crate::codegen::{
|
|||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::call_int_umin,
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
types::NDArrayType,
|
types::{structure::StructField, NDArrayType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -48,12 +48,13 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
|
||||||
|
self.get_type().get_fields(ctx.ctx).ndims
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
self.get_type()
|
self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.ndims
|
|
||||||
.ptr_by_gep(ctx, self.value, self.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the number of dimensions `ndims` into this instance.
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
@ -75,18 +76,19 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||||
|
self.get_type().get_fields(ctx.ctx).shape
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
/// `getelementptr` on the field.
|
/// `getelementptr` on the field.
|
||||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
self.get_type()
|
self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.shape
|
|
||||||
.ptr_by_gep(ctx, self.value, self.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||||
@ -105,13 +107,14 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
NDArrayShapeProxy(self)
|
NDArrayShapeProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||||
|
self.get_type().get_fields(ctx.ctx).data
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
self.get_type()
|
self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.data
|
|
||||||
.ptr_by_gep(ctx, self.value, self.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of data elements `data` into this instance.
|
/// Stores the array of data elements `data` into this instance.
|
||||||
@ -120,7 +123,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
.builder
|
.builder
|
||||||
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing data elements with the given element
|
/// Convenience method for creating a new array storing data elements with the given element
|
||||||
@ -188,12 +191,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
|
|||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
_: &G,
|
_: &G,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size<G: CodeGenerator + ?Sized>(
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
@ -285,12 +283,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
_: &G,
|
_: &G,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
|
self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
|
|
||||||
.map(BasicValueEnum::into_pointer_value)
|
|
||||||
.unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size<G: CodeGenerator + ?Sized>(
|
fn size<G: CodeGenerator + ?Sized>(
|
||||||
|
Loading…
Reference in New Issue
Block a user