Compare commits

..

10 Commits

Author SHA1 Message Date
David Mak 0cca9dcfc4 [core] WIP - Implemented construct_* for NDArrays 2024-11-15 12:39:28 +08:00
David Mak b8b34e0b2f [core] coregen/types: Implement StructFields for NDArray 2024-11-15 12:39:25 +08:00
David Mak db0e1eb3d4 [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-14 12:39:14 +08:00
David Mak 6bded88702 [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-14 12:39:14 +08:00
David Mak f70b8132ed [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Remove __STDC_VERSION__ guard
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-14 12:39:14 +08:00
David Mak 5c105fccff [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-14 12:39:14 +08:00
David Mak 9e5bbdd60f [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-14 12:39:14 +08:00
lyken 4679fdccb6 core/irrt: fix exception.hpp C++ castings 2024-11-14 12:39:14 +08:00
lyken 01952ce55f core/toplevel/helper: add {extract,create}_ndims 2024-11-14 12:39:14 +08:00
David Mak 91fb801c7d [core] codegen/types: Implement StructField{,s}
Loosely based on FieldTraversal.
2024-11-14 12:39:12 +08:00
4 changed files with 287 additions and 180 deletions

View File

@ -7,16 +7,15 @@ namespace {
* @brief The NDArray object * @brief The NDArray object
* *
* Official numpy implementation: * Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface * https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
* `data`. There are also minor differences in the struct layout.
*/ */
template<typename SizeT> template<typename SizeT>
struct NDArray { struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
/** /**
* @brief The number of bytes of a single element in `data`. * @brief The number of bytes of a single element in `data`.
*/ */
@ -42,10 +41,5 @@ struct NDArray {
* Note that `strides` can have negative values or contain 0. * Note that `strides` can have negative values or contain 0.
*/ */
SizeT* strides; SizeT* strides;
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
}; };
} // namespace } // namespace

View File

@ -1,5 +1,5 @@
use inkwell::{ use inkwell::{
context::{AsContextRef, Context, ContextRef}, context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue}, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
@ -31,19 +31,23 @@ pub struct NDArrayType<'ctx> {
#[derive(PartialEq, Eq, Clone, Copy)] #[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> { pub struct NDArrayStructFields<'ctx> {
pub data: StructField<'ctx, PointerValue<'ctx>>,
pub itemsize: StructField<'ctx, IntValue<'ctx>>, pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>, pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>, pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>, pub strides: StructField<'ctx, PointerValue<'ctx>>,
pub data: StructField<'ctx, PointerValue<'ctx>>,
} }
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> { impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self { fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = FieldIndexCounter::default(); let mut counter = FieldIndexCounter::default();
NDArrayStructFields { NDArrayStructFields {
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize), itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize), ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create( shape: StructField::create(
@ -56,21 +60,16 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
"strides", "strides",
llvm_usize.ptr_type(AddressSpace::default()), llvm_usize.ptr_type(AddressSpace::default()),
), ),
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
} }
} }
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> { fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![ vec![
self.data.into(),
self.itemsize.into(), self.itemsize.into(),
self.ndims.into(), self.ndims.into(),
self.shape.into(), self.shape.into(),
self.strides.into(), self.strides.into(),
self.data.into(),
] ]
} }
} }
@ -81,45 +80,106 @@ impl<'ctx> NDArrayType<'ctx> {
llvm_ty: PointerType<'ctx>, llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
let llvm_ndarray_ty = llvm_ty.get_element_type(); let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
}; };
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() { if llvm_ndarray_ty.count_fields() != 5 {
return Err(format!( return Err(format!(
"Expected {} fields in `NDArray`, got {}", "Expected 5 fields in `NDArray`, got {}",
llvm_expected_ty.len(),
llvm_ndarray_ty.count_fields() llvm_ndarray_ty.count_fields()
)); ));
} }
llvm_expected_ty let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
.iter() let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
.enumerate() return Err(format!("Expected pointer type for `ndarray.data`, got {ndarray_data_ty}"));
.map(|(i, expected_ty)| { };
(expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap()) let ndarray_data = ndarray_pdata.get_element_type();
}) let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
.try_for_each(|(expected_ty, actual_ty)| { return Err(format!(
if expected_ty == actual_ty { "Expected pointer-to-int type for `ndarray.data`, got pointer-to-{ndarray_data}"
Ok(()) ));
} else { };
Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}")) if ndarray_data.get_bit_width() != 8 {
return Err(format!(
"Expected pointer-to-8-bit int type for `ndarray.data`, got pointer-to-{}-bit int",
ndarray_data.get_bit_width()
));
}
let ndarray_itemsize_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_itemsize_ty) = IntType::try_from(ndarray_itemsize_ty) else {
return Err(format!(
"Expected int type for `ndarray.itemsize`, got {ndarray_itemsize_ty}"
));
};
if ndarray_itemsize_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.itemsize`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_itemsize_ty.get_bit_width()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.ndims`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.ndims`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_shape_ty = llvm_ndarray_ty.get_field_type_at_index(3).unwrap();
let Ok(ndarray_pshape) = PointerType::try_from(ndarray_shape_ty) else {
return Err(format!(
"Expected pointer type for `ndarray.shape`, got {ndarray_shape_ty}"
));
};
let ndarray_shape = ndarray_pshape.get_element_type();
let Ok(ndarray_shape) = IntType::try_from(ndarray_shape) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.shape`, got pointer-to-{ndarray_shape}"
));
};
if ndarray_shape.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.shape`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_shape.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(4).unwrap();
let Ok(ndarray_pstrides) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!(
"Expected pointer type for `ndarray.strides`, got {ndarray_dims_ty}"
));
};
let ndarray_strides = ndarray_pstrides.get_element_type();
let Ok(ndarray_strides) = IntType::try_from(ndarray_strides) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.strides`, got pointer-to-{ndarray_strides}"
));
};
if ndarray_strides.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.strides`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_strides.get_bit_width()
));
} }
})?;
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType // TODO: Move this into e.g. StructProxyType
#[must_use] #[must_use]
fn fields( fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize) NDArrayStructFields::new(ctx, llvm_usize)
} }
@ -127,7 +187,7 @@ impl<'ctx> NDArrayType<'ctx> {
#[must_use] #[must_use]
pub fn get_fields( pub fn get_fields(
&self, &self,
ctx: impl AsContextRef<'ctx>, ctx: &'ctx Context,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> { ) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize) Self::fields(ctx, llvm_usize)

View File

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use inkwell::{ use inkwell::{
context::AsContextRef, context::Context,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
}; };
@ -22,7 +22,7 @@ use crate::codegen::CodeGenContext;
/// ``` /// ```
pub trait StructFields<'ctx>: Eq + Copy { pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types. /// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self; fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in /// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition. /// the type definition.
@ -94,19 +94,6 @@ where
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData } StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
} }
/// Creates an instance of [`StructField`] with a given index.
///
/// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub(super) fn create_at(
index: u32,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32 /// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`. /// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep( pub fn ptr_by_array_gep(

View File

@ -3,6 +3,7 @@ use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Itertools;
use super::{ use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
@ -12,7 +13,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin, llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::NDArrayType, types::{structure::StructFields, NDArrayType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -50,127 +51,29 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.itemsize
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.strides
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type() let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize) .get_fields(ctx.ctx, self.llvm_usize)
.data .into_iter()
.ptr_by_gep(ctx, self.value, self.name) .find_position(|field| field.0 == "data")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
@ -206,6 +109,169 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self) NDArrayDataProxy(self)
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type().get_fields(ctx.ctx, self.llvm_usize).ndims.ptr_by_gep(
ctx,
self.as_base_value(),
self.name,
)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "itemsize")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "shape")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `stride` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.strides.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "strides")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {