Compare commits
11 Commits
0cca9dcfc4
...
c862dbd861
Author | SHA1 | Date |
---|---|---|
David Mak | c862dbd861 | |
David Mak | 684aafe54c | |
David Mak | a770d9f415 | |
David Mak | 6f702ac250 | |
David Mak | 64ec66d3dd | |
David Mak | 3c336b0ea5 | |
David Mak | 9fab65109a | |
David Mak | 3a5e7a98b1 | |
lyken | d3fb4204e7 | |
lyken | 8631fc8b58 | |
David Mak | 4b666f8706 |
|
@ -7,15 +7,16 @@ namespace {
|
||||||
* @brief The NDArray object
|
* @brief The NDArray object
|
||||||
*
|
*
|
||||||
* Official numpy implementation:
|
* Official numpy implementation:
|
||||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
|
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
||||||
|
*
|
||||||
|
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
||||||
|
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
||||||
|
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
||||||
|
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
||||||
|
* `data`. There are also minor differences in the struct layout.
|
||||||
*/
|
*/
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
struct NDArray {
|
struct NDArray {
|
||||||
/**
|
|
||||||
* @brief The underlying data this `ndarray` is pointing to.
|
|
||||||
*/
|
|
||||||
void* data;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief The number of bytes of a single element in `data`.
|
* @brief The number of bytes of a single element in `data`.
|
||||||
*/
|
*/
|
||||||
|
@ -41,5 +42,10 @@ struct NDArray {
|
||||||
* Note that `strides` can have negative values or contain 0.
|
* Note that `strides` can have negative values or contain 0.
|
||||||
*/
|
*/
|
||||||
SizeT* strides;
|
SizeT* strides;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The underlying data this `ndarray` is pointing to.
|
||||||
|
*/
|
||||||
|
void* data;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
|
@ -1,5 +1,5 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::{AsContextRef, Context, ContextRef},
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{IntValue, PointerValue},
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
|
@ -31,23 +31,19 @@ pub struct NDArrayType<'ctx> {
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
pub struct NDArrayStructFields<'ctx> {
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
||||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
|
||||||
let mut counter = FieldIndexCounter::default();
|
let mut counter = FieldIndexCounter::default();
|
||||||
|
|
||||||
NDArrayStructFields {
|
NDArrayStructFields {
|
||||||
data: StructField::create(
|
|
||||||
&mut counter,
|
|
||||||
"data",
|
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()),
|
|
||||||
),
|
|
||||||
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
||||||
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||||
shape: StructField::create(
|
shape: StructField::create(
|
||||||
|
@ -60,16 +56,21 @@ impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
||||||
"strides",
|
"strides",
|
||||||
llvm_usize.ptr_type(AddressSpace::default()),
|
llvm_usize.ptr_type(AddressSpace::default()),
|
||||||
),
|
),
|
||||||
|
data: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"data",
|
||||||
|
ctx.i8_type().ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
vec![
|
vec![
|
||||||
self.data.into(),
|
|
||||||
self.itemsize.into(),
|
self.itemsize.into(),
|
||||||
self.ndims.into(),
|
self.ndims.into(),
|
||||||
self.shape.into(),
|
self.shape.into(),
|
||||||
self.strides.into(),
|
self.strides.into(),
|
||||||
|
self.data.into(),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,106 +81,45 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
llvm_ty: PointerType<'ctx>,
|
llvm_ty: PointerType<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
|
let ctx = llvm_ty.get_context();
|
||||||
|
|
||||||
|
let llvm_expected_ty = Self::fields(ctx, llvm_usize).into_vec();
|
||||||
|
|
||||||
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||||
};
|
};
|
||||||
if llvm_ndarray_ty.count_fields() != 5 {
|
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Expected 5 fields in `NDArray`, got {}",
|
"Expected {} fields in `NDArray`, got {}",
|
||||||
|
llvm_expected_ty.len(),
|
||||||
llvm_ndarray_ty.count_fields()
|
llvm_ndarray_ty.count_fields()
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
llvm_expected_ty
|
||||||
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
|
.iter()
|
||||||
return Err(format!("Expected pointer type for `ndarray.data`, got {ndarray_data_ty}"));
|
.enumerate()
|
||||||
};
|
.map(|(i, expected_ty)| {
|
||||||
let ndarray_data = ndarray_pdata.get_element_type();
|
(expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
|
||||||
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
|
})
|
||||||
return Err(format!(
|
.try_for_each(|(expected_ty, actual_ty)| {
|
||||||
"Expected pointer-to-int type for `ndarray.data`, got pointer-to-{ndarray_data}"
|
if expected_ty == actual_ty {
|
||||||
));
|
Ok(())
|
||||||
};
|
} else {
|
||||||
if ndarray_data.get_bit_width() != 8 {
|
Err(format!("Expected {expected_ty} for `ndarray.data`, got {actual_ty}"))
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-8-bit int type for `ndarray.data`, got pointer-to-{}-bit int",
|
|
||||||
ndarray_data.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_itemsize_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
|
||||||
let Ok(ndarray_itemsize_ty) = IntType::try_from(ndarray_itemsize_ty) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected int type for `ndarray.itemsize`, got {ndarray_itemsize_ty}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
if ndarray_itemsize_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected {}-bit int type for `ndarray.itemsize`, got {}-bit int",
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
ndarray_itemsize_ty.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
|
||||||
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
|
||||||
return Err(format!("Expected int type for `ndarray.ndims`, got {ndarray_ndims_ty}"));
|
|
||||||
};
|
|
||||||
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected {}-bit int type for `ndarray.ndims`, got {}-bit int",
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
ndarray_ndims_ty.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_shape_ty = llvm_ndarray_ty.get_field_type_at_index(3).unwrap();
|
|
||||||
let Ok(ndarray_pshape) = PointerType::try_from(ndarray_shape_ty) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer type for `ndarray.shape`, got {ndarray_shape_ty}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
let ndarray_shape = ndarray_pshape.get_element_type();
|
|
||||||
let Ok(ndarray_shape) = IntType::try_from(ndarray_shape) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-int type for `ndarray.shape`, got pointer-to-{ndarray_shape}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
if ndarray_shape.get_bit_width() != llvm_usize.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-{}-bit int type for `ndarray.shape`, got pointer-to-{}-bit int",
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
ndarray_shape.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(4).unwrap();
|
|
||||||
let Ok(ndarray_pstrides) = PointerType::try_from(ndarray_dims_ty) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer type for `ndarray.strides`, got {ndarray_dims_ty}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
let ndarray_strides = ndarray_pstrides.get_element_type();
|
|
||||||
let Ok(ndarray_strides) = IntType::try_from(ndarray_strides) else {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-int type for `ndarray.strides`, got pointer-to-{ndarray_strides}"
|
|
||||||
));
|
|
||||||
};
|
|
||||||
if ndarray_strides.get_bit_width() != llvm_usize.get_bit_width() {
|
|
||||||
return Err(format!(
|
|
||||||
"Expected pointer-to-{}-bit int type for `ndarray.strides`, got pointer-to-{}-bit int",
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
ndarray_strides.get_bit_width()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this into e.g. StructProxyType
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
fn fields(
|
||||||
|
ctx: impl AsContextRef<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> NDArrayStructFields<'ctx> {
|
||||||
NDArrayStructFields::new(ctx, llvm_usize)
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +127,7 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn get_fields(
|
pub fn get_fields(
|
||||||
&self,
|
&self,
|
||||||
ctx: &'ctx Context,
|
ctx: impl AsContextRef<'ctx>,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> NDArrayStructFields<'ctx> {
|
) -> NDArrayStructFields<'ctx> {
|
||||||
Self::fields(ctx, llvm_usize)
|
Self::fields(ctx, llvm_usize)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::AsContextRef,
|
||||||
types::{BasicTypeEnum, IntType},
|
types::{BasicTypeEnum, IntType},
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||||
};
|
};
|
||||||
|
@ -22,7 +22,7 @@ use crate::codegen::CodeGenContext;
|
||||||
/// ```
|
/// ```
|
||||||
pub trait StructFields<'ctx>: Eq + Copy {
|
pub trait StructFields<'ctx>: Eq + Copy {
|
||||||
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
||||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self;
|
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
|
||||||
|
|
||||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||||
/// the type definition.
|
/// the type definition.
|
||||||
|
@ -94,6 +94,19 @@ where
|
||||||
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`StructField`] with a given index.
|
||||||
|
///
|
||||||
|
/// * `index` - The index of this field within its enclosing structure.
|
||||||
|
/// * `name` - Name of the field.
|
||||||
|
/// * `ty` - The type of this field.
|
||||||
|
pub(super) fn create_at(
|
||||||
|
index: u32,
|
||||||
|
name: &'static str,
|
||||||
|
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||||
|
) -> Self {
|
||||||
|
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
||||||
/// {idx...}, i32 {self.index}`.
|
/// {idx...}, i32 {self.index}`.
|
||||||
pub fn ptr_by_array_gep(
|
pub fn ptr_by_array_gep(
|
||||||
|
|
|
@ -3,7 +3,6 @@ use inkwell::{
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
||||||
|
@ -13,7 +12,7 @@ use crate::codegen::{
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::call_int_umin,
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
types::{structure::StructFields, NDArrayType},
|
types::NDArrayType,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -51,29 +50,127 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.ndims
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
|
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of dimensions of this `NDArray` as a value.
|
||||||
|
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
|
||||||
|
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.itemsize
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the size of each element `itemsize` into this instance.
|
||||||
|
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the size of each element of this `NDArray` as a value.
|
||||||
|
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
|
/// `getelementptr` on the field.
|
||||||
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.shape
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
|
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||||
|
pub fn create_shape(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
|
||||||
|
NDArrayShapeProxy(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `stride` array, as if by calling
|
||||||
|
/// `getelementptr` on the field.
|
||||||
|
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.strides
|
||||||
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
|
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing the stride with the given `size`.
|
||||||
|
pub fn create_strides(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
|
||||||
|
NDArrayStridesProxy(self)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
self.get_type()
|
||||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
let field_offset = self
|
|
||||||
.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
.into_iter()
|
.data
|
||||||
.find_position(|field| field.0 == "data")
|
.ptr_by_gep(ctx, self.value, self.name)
|
||||||
.unwrap()
|
|
||||||
.0 as u64;
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of data elements `data` into this instance.
|
/// Stores the array of data elements `data` into this instance.
|
||||||
|
@ -109,169 +206,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
||||||
NDArrayDataProxy(self)
|
NDArrayDataProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
self.get_type().get_fields(ctx.ctx, self.llvm_usize).ndims.ptr_by_gep(
|
|
||||||
ctx,
|
|
||||||
self.as_base_value(),
|
|
||||||
self.name,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the number of dimensions `ndims` into this instance.
|
|
||||||
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
|
||||||
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of dimensions of this `NDArray` as a value.
|
|
||||||
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the size of each element of this `NDArray`.
|
|
||||||
fn ptr_to_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
let field_offset = self
|
|
||||||
.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.into_iter()
|
|
||||||
.find_position(|field| field.0 == "itemsize")
|
|
||||||
.unwrap()
|
|
||||||
.0 as u64;
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the size of each element `itemsize` into this instance.
|
|
||||||
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
|
||||||
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the size of each element of this `NDArray` as a value.
|
|
||||||
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
|
||||||
/// `getelementptr` on the field.
|
|
||||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
let field_offset = self
|
|
||||||
.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.into_iter()
|
|
||||||
.find_position(|field| field.0 == "shape")
|
|
||||||
.unwrap()
|
|
||||||
.0 as u64;
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
|
||||||
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
|
||||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
|
||||||
pub fn create_shape(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
|
|
||||||
NDArrayShapeProxy(self)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `stride` array, as if by calling
|
|
||||||
/// `getelementptr` on the field.
|
|
||||||
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let var_name = self.name.map(|v| format!("{v}.strides.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
let field_offset = self
|
|
||||||
.get_type()
|
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
|
||||||
.into_iter()
|
|
||||||
.find_position(|field| field.0 == "strides")
|
|
||||||
.unwrap()
|
|
||||||
.0 as u64;
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
|
||||||
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
|
||||||
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing the stride with the given `size`.
|
|
||||||
pub fn create_strides(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
|
|
||||||
NDArrayStridesProxy(self)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
|
|
Loading…
Reference in New Issue