[core] coregen/types: Implement StructFields for NDArray

This commit is contained in:
David Mak 2024-11-13 15:53:29 +08:00
parent db0e1eb3d4
commit b8b34e0b2f
2 changed files with 203 additions and 85 deletions

View File

@ -1,11 +1,15 @@
use inkwell::{ use inkwell::{
context::Context, context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue, values::{IntValue, PointerValue},
AddressSpace, AddressSpace,
}; };
use itertools::Itertools;
use super::ProxyType; use super::{
structure::{FieldIndexCounter, StructField, StructFields},
ProxyType,
};
use crate::codegen::{ use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue}, values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator}, {CodeGenContext, CodeGenerator},
@ -19,6 +23,51 @@ pub struct NDArrayType<'ctx> {
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
} }
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> {
pub data: StructField<'ctx, PointerValue<'ctx>>,
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
let mut counter = FieldIndexCounter::default();
NDArrayStructFields {
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
strides: StructField::create(
&mut counter,
"strides",
llvm_usize.ptr_type(AddressSpace::default()),
),
}
}
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![
self.data.into(),
self.itemsize.into(),
self.ndims.into(),
self.shape.into(),
self.strides.into(),
]
}
}
impl<'ctx> NDArrayType<'ctx> { impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable( pub fn is_representable(
@ -86,19 +135,34 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(()) Ok(())
} }
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`. /// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use] #[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
// //
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data : Pointer to an array containing the array data // * data : Pointer to an array containing the array data
let field_tys = [ // * itemsize: The size of each NDArray elements in bytes
llvm_usize.into(), // * ndims : Number of dimensions in the array
llvm_usize.ptr_type(AddressSpace::default()).into(), // * shape : Pointer to an array containing the shape of the NDArray
ctx.i8_type().ptr_type(AddressSpace::default()).into(), // * strides : Pointer to an array indicating the number of bytes between each element at a dimension
]; let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }

View File

@ -3,6 +3,7 @@ use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Itertools;
use super::{ use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
@ -12,7 +13,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin, llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::NDArrayType, types::{structure::StructFields, NDArrayType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -48,90 +49,25 @@ impl<'ctx> NDArrayValue<'ctx> {
NDArrayValue { value: ptr, dtype, llvm_usize, name } NDArrayValue { value: ptr, dtype, llvm_usize, name }
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "data")
.unwrap()
.0 as u64;
unsafe { unsafe {
ctx.builder ctx.builder
.build_in_bounds_gep( .build_in_bounds_gep(
self.as_base_value(), self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], &[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(), var_name.as_str(),
) )
.unwrap() .unwrap()
@ -171,6 +107,124 @@ impl<'ctx> NDArrayValue<'ctx> {
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self) NDArrayDataProxy(self)
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type().get_fields(ctx.ctx, self.llvm_usize).ndims.ptr_by_gep(
ctx,
self.as_base_value(),
self.name,
)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "itemsize")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "shape")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_dim_sizes(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
NDArrayDimsProxy(self)
}
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {