[core] codegen: Add dtype to NDArrayType

We won't have this once NDArray is refactored to strided impl.
This commit is contained in:
David Mak 2024-11-08 15:49:01 +08:00
parent f276547800
commit 4df1ec3282
1 changed files with 11 additions and 9 deletions

View File

@ -15,6 +15,7 @@ use crate::codegen::{
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> { pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>, ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
} }
@ -112,7 +113,14 @@ impl<'ctx> NDArrayType<'ctx> {
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, llvm_usize } NDArrayType {
ty: ptr_ty,
dtype: ptr_ty
.get_element_type()
.try_into()
.expect("Expected BasicTypeEnum for dtype of NDArray"),
llvm_usize,
}
} }
/// Returns the type of the `size` field of this `ndarray` type. /// Returns the type of the `size` field of this `ndarray` type.
@ -128,14 +136,8 @@ impl<'ctx> NDArrayType<'ctx> {
/// Returns the element type of this `ndarray` type. /// Returns the element type of this `ndarray` type.
#[must_use] #[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> { pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type() self.dtype
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
} }
} }