[core] coregen/types: Implement StructFields for NDArray

This commit is contained in:
David Mak 2024-11-13 15:53:29 +08:00
parent db0e1eb3d4
commit ebe1b5b85f
2 changed files with 207 additions and 86 deletions

View File

@ -1,11 +1,15 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use super::ProxyType;
use super::{
structure::{FieldIndexCounter, StructField, StructFields},
ProxyType,
};
use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator},
@ -19,6 +23,51 @@ pub struct NDArrayType<'ctx> {
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> {
pub data: StructField<'ctx, PointerValue<'ctx>>,
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
let mut counter = FieldIndexCounter::default();
NDArrayStructFields {
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
strides: StructField::create(
&mut counter,
"strides",
llvm_usize.ptr_type(AddressSpace::default()),
),
}
}
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![
self.data.into(),
self.itemsize.into(),
self.ndims.into(),
self.shape.into(),
self.strides.into(),
]
}
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
@ -86,19 +135,39 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(())
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data : Pointer to an array containing the array data
let field_tys = [
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
ctx.i8_type().ptr_type(AddressSpace::default()).into(),
];
// * itemsize: The size of each NDArray elements in bytes
// * ndims : Number of dimensions in the array
// * shape : Pointer to an array containing the shape of the NDArray
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
let field_tys = Self::fields(ctx, llvm_usize)
.into_iter()
.map(|field| field.1)
.collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}

View File

@ -3,7 +3,7 @@ use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use itertools::Itertools;
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
@ -12,7 +12,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
types::NDArrayType,
types::{NDArrayType, structure::StructFields},
CodeGenContext, CodeGenerator,
};
@ -48,90 +48,25 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, llvm_usize, name }
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "data")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
@ -171,6 +106,123 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self)
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.as_base_value(), self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "itemsize")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "shape")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {