[core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
This commit is contained in:
parent
1ba2e287a6
commit
1a535db558
@ -15,6 +15,7 @@ use crate::codegen::{
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDArrayType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
@ -112,7 +113,14 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
NDArrayType { ty: ptr_ty, llvm_usize }
|
||||
NDArrayType {
|
||||
ty: ptr_ty,
|
||||
dtype: ptr_ty
|
||||
.get_element_type()
|
||||
.try_into()
|
||||
.expect("Expected BasicTypeEnum for dtype of NDArray"),
|
||||
llvm_usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the type of the `size` field of this `ndarray` type.
|
||||
@ -128,14 +136,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
|
||||
/// Returns the element type of this `ndarray` type.
|
||||
#[must_use]
|
||||
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
|
||||
self.as_base_type()
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.get_field_type_at_index(2)
|
||||
.map(BasicTypeEnum::into_pointer_type)
|
||||
.map(PointerType::get_element_type)
|
||||
.unwrap()
|
||||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||
self.dtype
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user