diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index d6887322..e1795f07 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -1,11 +1,15 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; +use itertools::Itertools; -use super::ProxyType; +use super::{ + structure::{FieldIndexCounter, StructField, StructFields}, + ProxyType, +}; use crate::codegen::{ values::{ArraySliceValue, NDArrayValue, ProxyValue}, {CodeGenContext, CodeGenerator}, @@ -19,6 +23,51 @@ pub struct NDArrayType<'ctx> { llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct NDArrayStructFields<'ctx> { + pub data: StructField<'ctx, PointerValue<'ctx>>, + pub itemsize: StructField<'ctx, IntValue<'ctx>>, + pub ndims: StructField<'ctx, IntValue<'ctx>>, + pub shape: StructField<'ctx, PointerValue<'ctx>>, + pub strides: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { + fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + NDArrayStructFields { + data: StructField::create( + &mut counter, + "data", + ctx.i8_type().ptr_type(AddressSpace::default()), + ), + itemsize: StructField::create(&mut counter, "itemsize", llvm_usize), + ndims: StructField::create(&mut counter, "ndims", llvm_usize), + shape: StructField::create( + &mut counter, + "shape", + llvm_usize.ptr_type(AddressSpace::default()), + ), + strides: StructField::create( + &mut counter, + "strides", + llvm_usize.ptr_type(AddressSpace::default()), + ), + } + } + + fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { + vec![ + self.data.into(), + self.itemsize.into(), + self.ndims.into(), + self.shape.into(), + self.strides.into(), + ] + } +} + impl<'ctx> NDArrayType<'ctx> { /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. pub fn is_representable( @@ -86,19 +135,39 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) } + // TODO: Move this into e.g. StructProxyType + #[must_use] + fn fields( + ctx: &'ctx Context, + llvm_usize: IntType<'ctx>, + ) -> NDArrayStructFields<'ctx> { + NDArrayStructFields::new(ctx, llvm_usize) + } + + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields( + &self, + ctx: &'ctx Context, + llvm_usize: IntType<'ctx>, + ) -> NDArrayStructFields<'ctx> { + Self::fields(ctx, llvm_usize) + } + /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let field_tys = [ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - ctx.i8_type().ptr_type(AddressSpace::default()).into(), - ]; + // * data : Pointer to an array containing the array data + // * itemsize: The size of each NDArray elements in bytes + // * ndims : Number of dimensions in the array + // * shape : Pointer to an array containing the shape of the NDArray + // * strides : Pointer to an array indicating the number of bytes between each element at a dimension + let field_tys = Self::fields(ctx, llvm_usize) + .into_iter() + .map(|field| field.1) + .collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 732ed0d3..b4fcb89d 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -3,7 +3,7 @@ use inkwell::{ values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; - +use itertools::Itertools; use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, @@ -12,7 +12,7 @@ use crate::codegen::{ irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, - types::NDArrayType, + types::{NDArrayType, structure::StructFields}, CodeGenContext, CodeGenerator, }; @@ -48,90 +48,25 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayValue { value: ptr, dtype, llvm_usize, name } } - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); - - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); - } - - /// Returns the number of dimensions of this `NDArray` as a value. - pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() - } - - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); - } - - /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( - &self, - ctx: &CodeGenContext<'ctx, '_>, - llvm_usize: IntType<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. - #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) - } - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + let field_offset = self + .get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .into_iter() + .find_position(|field| field.0 == "data") + .unwrap() + .0 as u64; + unsafe { ctx.builder .build_in_bounds_gep( self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + &[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)], var_name.as_str(), ) .unwrap() @@ -171,6 +106,123 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { NDArrayDataProxy(self) } + + /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. + fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .ndims + .ptr_by_gep(ctx, self.as_base_value(), self.name) + } + + /// Stores the number of dimensions `ndims` into this instance. + pub fn store_ndims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_store(pndims, ndims).unwrap(); + } + + /// Returns the number of dimensions of this `NDArray` as a value. + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + } + + /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` + /// on the field. + fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default(); + + let field_offset = self + .get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .into_iter() + .find_position(|field| field.0 == "itemsize") + .unwrap() + .0 as u64; + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the size of each element `itemsize` into this instance. + pub fn store_itemsize( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_store(pndims, ndims).unwrap(); + } + + /// Returns the size of each element of this `NDArray` as a value. + pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + } + + /// Returns the double-indirection pointer to the `shape` array, as if by calling + /// `getelementptr` on the field. + fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default(); + + let field_offset = self + .get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .into_iter() + .find_position(|field| field.0 == "shape") + .unwrap() + .0 as u64; + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the array of dimension sizes `dims` into this instance. + fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); + } + + /// Convenience method for creating a new array storing dimension sizes with the given `size`. + pub fn create_dim_sizes( + &self, + ctx: &CodeGenContext<'ctx, '_>, + llvm_usize: IntType<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); + } + + /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. + #[must_use] + pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { + NDArrayDimsProxy(self) + } } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {