diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 98bcdb62..3f25f828 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -15,6 +15,7 @@ use crate::codegen::{ #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, } @@ -112,7 +113,14 @@ impl<'ctx> NDArrayType<'ctx> { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - NDArrayType { ty: ptr_ty, llvm_usize } + NDArrayType { + ty: ptr_ty, + dtype: ptr_ty + .get_element_type() + .try_into() + .expect("Expected BasicTypeEnum for dtype of NDArray"), + llvm_usize, + } } /// Returns the type of the `size` field of this `ndarray` type. @@ -128,14 +136,8 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the element type of this `ndarray` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(2) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> BasicTypeEnum<'ctx> { + self.dtype } }