Compare commits

..

11 Commits

Author SHA1 Message Date
David Mak c862dbd861 [core] WIP - Implemented construct_* for NDArrays 2024-11-15 15:29:22 +08:00
David Mak 684aafe54c [core] Add itemsize and strides to NDArray struct 2024-11-15 15:28:37 +08:00
David Mak a770d9f415 [core] coregen/types: Implement StructFields for NDArray
Also rename some fields to better align with their naming in numpy.
2024-11-15 15:27:45 +08:00
David Mak 6f702ac250 [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-15 15:19:23 +08:00
David Mak 64ec66d3dd [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-15 15:19:23 +08:00
David Mak 3c336b0ea5 [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Remove __STDC_VERSION__ guard
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-15 15:19:23 +08:00
David Mak 9fab65109a [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-15 15:19:23 +08:00
David Mak 3a5e7a98b1 [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-15 15:19:23 +08:00
lyken d3fb4204e7 core/irrt: fix exception.hpp C++ castings 2024-11-15 15:19:23 +08:00
lyken 8631fc8b58 core/toplevel/helper: add {extract,create}_ndims 2024-11-15 15:19:23 +08:00
David Mak 4b666f8706 [core] codegen/types: Implement StructField{,s}
Loosely based on FieldTraversal.
2024-11-15 15:19:06 +08:00
4 changed files with 180 additions and 287 deletions

View File

@ -7,15 +7,16 @@ namespace {
* @brief The NDArray object * @brief The NDArray object
* *
* Official numpy implementation: * Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst * https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
*
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
* `data`. There are also minor differences in the struct layout.
*/ */
template<typename SizeT> template<typename SizeT>
struct NDArray { struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
/** /**
* @brief The number of bytes of a single element in `data`. * @brief The number of bytes of a single element in `data`.
*/ */
@ -41,5 +42,10 @@ struct NDArray {
* Note that `strides` can have negative values or contain 0. * Note that `strides` can have negative values or contain 0.
*/ */
SizeT* strides; SizeT* strides;
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
}; };
} // namespace } // namespace

View File

@ -1,5 +1,5 @@
use inkwell::{ use inkwell::{
context::Context, context::{AsContextRef, Context, ContextRef},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue}, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
@ -31,23 +31,19 @@ pub struct NDArrayType<'ctx> {
#[derive(PartialEq, Eq, Clone, Copy)] #[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> { pub struct NDArrayStructFields<'ctx> {
pub data: StructField<'ctx, PointerValue<'ctx>>,
pub itemsize: StructField<'ctx, IntValue<'ctx>>, pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>, pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>, pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>, pub strides: StructField<'ctx, PointerValue<'ctx>>,
pub data: StructField<'ctx, PointerValue<'ctx>>,
} }
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = FieldIndexCounter::default(); let mut counter = FieldIndexCounter::default();
NDArrayStructFields { NDArrayStructFields {
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize), itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize), ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create( shape: StructField::create(
@ -60,16 +56,21 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
"strides", "strides",
llvm_usize.ptr_type(AddressSpace::default()), llvm_usize.ptr_type(AddressSpace::default()),
), ),
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
} }
} }
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![ vec![
self.data.into(),
self.itemsize.into(), self.itemsize.into(),
self.ndims.into(), self.ndims.into(),
self.shape.into(), self.shape.into(),
self.strides.into(), self.strides.into(),
self.data.into(),
] ]
} }
} }
@ -80,106 +81,45 @@ impl<'ctx> NDArrayType<'ctx> {
llvm_ty: PointerType<'ctx>, llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
let llvm_ndarray_ty = llvm_ty.get_element_type(); let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
}; };
if llvm_ndarray_ty.count_fields() != 5 { if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
return Err(format!( return Err(format!(
"Expected 5 fields in `NDArray`, got {}", "Expected {} fields in `NDArray`, got {}",
llvm_expected_ty.len(),
llvm_ndarray_ty.count_fields() llvm_ndarray_ty.count_fields()
)); ));
} }
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); llvm_expected_ty
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else { .iter()
return Err(format!("Expected pointer type for `ndarray.data`, got {ndarray_data_ty}")); .enumerate()
}; .map(|(i, expected_ty)| {
let ndarray_data = ndarray_pdata.get_element_type(); (expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else { })
return Err(format!( .try_for_each(|(expected_ty, actual_ty)| {
"Expected pointer-to-int type for `ndarray.data`, got pointer-to-{ndarray_data}" if expected_ty == actual_ty {
)); Ok(())
}; } else {
if ndarray_data.get_bit_width() != 8 { Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
return Err(format!( }
"Expected pointer-to-8-bit int type for `ndarray.data`, got pointer-to-{}-bit int", })?;
ndarray_data.get_bit_width()
));
}
let ndarray_itemsize_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_itemsize_ty) = IntType::try_from(ndarray_itemsize_ty) else {
return Err(format!(
"Expected int type for `ndarray.itemsize`, got {ndarray_itemsize_ty}"
));
};
if ndarray_itemsize_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.itemsize`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_itemsize_ty.get_bit_width()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.ndims`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.ndims`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_shape_ty = llvm_ndarray_ty.get_field_type_at_index(3).unwrap();
let Ok(ndarray_pshape) = PointerType::try_from(ndarray_shape_ty) else {
return Err(format!(
"Expected pointer type for `ndarray.shape`, got {ndarray_shape_ty}"
));
};
let ndarray_shape = ndarray_pshape.get_element_type();
let Ok(ndarray_shape) = IntType::try_from(ndarray_shape) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.shape`, got pointer-to-{ndarray_shape}"
));
};
if ndarray_shape.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.shape`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_shape.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(4).unwrap();
let Ok(ndarray_pstrides) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!(
"Expected pointer type for `ndarray.strides`, got {ndarray_dims_ty}"
));
};
let ndarray_strides = ndarray_pstrides.get_element_type();
let Ok(ndarray_strides) = IntType::try_from(ndarray_strides) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.strides`, got pointer-to-{ndarray_strides}"
));
};
if ndarray_strides.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.strides`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_strides.get_bit_width()
));
}
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType // TODO: Move this into e.g. StructProxyType
#[must_use] #[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { fn fields(
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize) NDArrayStructFields::new(ctx, llvm_usize)
} }
@ -187,7 +127,7 @@ impl<'ctx> NDArrayType<'ctx> {
#[must_use] #[must_use]
pub fn get_fields( pub fn get_fields(
&self, &self,
ctx: &'ctx Context, ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> { ) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize) Self::fields(ctx, llvm_usize)

View File

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use inkwell::{ use inkwell::{
context::Context, context::AsContextRef,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
}; };
@ -22,7 +22,7 @@ use crate::codegen::CodeGenContext;
/// ``` /// ```
pub trait StructFields<'ctx>: Eq + Copy { pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types. /// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self; fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in /// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition. /// the type definition.
@ -94,6 +94,19 @@ where
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData } StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
} }
/// Creates an instance of [`StructField`] with a given index.
///
/// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub(super) fn create_at(
index: u32,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32 /// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`. /// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep( pub fn ptr_by_array_gep(

View File

@ -3,7 +3,6 @@ use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Itertools;
use super::{ use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
@ -13,7 +12,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin, llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::{structure::StructFields, NDArrayType}, types::NDArrayType,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -51,29 +50,127 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.itemsize
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.strides
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); self.get_type()
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize) .get_fields(ctx.ctx, self.llvm_usize)
.into_iter() .data
.find_position(|field| field.0 == "data") .ptr_by_gep(ctx, self.value, self.name)
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
@ -109,169 +206,6 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self) NDArrayDataProxy(self)
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type().get_fields(ctx.ctx, self.llvm_usize).ndims.ptr_by_gep(
ctx,
self.as_base_value(),
self.name,
)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "itemsize")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "shape")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.strides.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "strides")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {