[core] coregen/types: Implement StructFields for NDArray
This commit is contained in:
parent
db0e1eb3d4
commit
b8b34e0b2f
|
@ -1,11 +1,15 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::IntValue,
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
};
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::ProxyType;
|
use super::{
|
||||||
|
structure::{FieldIndexCounter, StructField, StructFields},
|
||||||
|
ProxyType,
|
||||||
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
values::{ArraySliceValue, NDArrayValue, ProxyValue},
|
||||||
{CodeGenContext, CodeGenerator},
|
{CodeGenContext, CodeGenerator},
|
||||||
|
@ -19,6 +23,51 @@ pub struct NDArrayType<'ctx> {
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct NDArrayStructFields<'ctx> {
|
||||||
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
||||||
|
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
let mut counter = FieldIndexCounter::default();
|
||||||
|
|
||||||
|
NDArrayStructFields {
|
||||||
|
data: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"data",
|
||||||
|
ctx.i8_type().ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
||||||
|
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||||
|
shape: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"shape",
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
strides: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"strides",
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
|
vec![
|
||||||
|
self.data.into(),
|
||||||
|
self.itemsize.into(),
|
||||||
|
self.ndims.into(),
|
||||||
|
self.shape.into(),
|
||||||
|
self.strides.into(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> NDArrayType<'ctx> {
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||||
pub fn is_representable(
|
pub fn is_representable(
|
||||||
|
@ -86,19 +135,34 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
|
||||||
|
NDArrayStructFields::new(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_fields(
|
||||||
|
&self,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> NDArrayStructFields<'ctx> {
|
||||||
|
Self::fields(ctx, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
|
||||||
//
|
//
|
||||||
// * num_dims: Number of dimensions in the array
|
|
||||||
// * dims: Pointer to an array containing the size of each dimension
|
|
||||||
// * data : Pointer to an array containing the array data
|
// * data : Pointer to an array containing the array data
|
||||||
let field_tys = [
|
// * itemsize: The size of each NDArray elements in bytes
|
||||||
llvm_usize.into(),
|
// * ndims : Number of dimensions in the array
|
||||||
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
// * shape : Pointer to an array containing the shape of the NDArray
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()).into(),
|
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
||||||
];
|
let field_tys =
|
||||||
|
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
|
|
||||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ use inkwell::{
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
||||||
|
@ -12,7 +13,7 @@ use crate::codegen::{
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::call_int_umin,
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
types::NDArrayType,
|
types::{structure::StructFields, NDArrayType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -48,90 +49,25 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
NDArrayValue { value: ptr, dtype, llvm_usize, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the number of dimensions `ndims` into this instance.
|
|
||||||
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
generator: &G,
|
|
||||||
ndims: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
|
||||||
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_store(pndims, ndims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of dimensions of this `NDArray` as a value.
|
|
||||||
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
let pndims = self.ptr_to_ndims(ctx);
|
|
||||||
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
|
||||||
/// on the field.
|
|
||||||
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.as_base_value(),
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
|
||||||
var_name.as_str(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stores the array of dimension sizes `dims` into this instance.
|
|
||||||
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
|
||||||
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
|
||||||
pub fn create_dim_sizes(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
llvm_usize: IntType<'ctx>,
|
|
||||||
size: IntValue<'ctx>,
|
|
||||||
) {
|
|
||||||
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
|
|
||||||
NDArrayDimsProxy(self)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
let field_offset = self
|
||||||
|
.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.into_iter()
|
||||||
|
.find_position(|field| field.0 == "data")
|
||||||
|
.unwrap()
|
||||||
|
.0 as u64;
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_in_bounds_gep(
|
.build_in_bounds_gep(
|
||||||
self.as_base_value(),
|
self.as_base_value(),
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
|
||||||
var_name.as_str(),
|
var_name.as_str(),
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -171,6 +107,124 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
|
||||||
NDArrayDataProxy(self)
|
NDArrayDataProxy(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.get_type().get_fields(ctx.ctx, self.llvm_usize).ndims.ptr_by_gep(
|
||||||
|
ctx,
|
||||||
|
self.as_base_value(),
|
||||||
|
self.name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
|
pub fn store_ndims<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of dimensions of this `NDArray` as a value.
|
||||||
|
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
||||||
|
/// on the field.
|
||||||
|
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.itemsize.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
let field_offset = self
|
||||||
|
.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.into_iter()
|
||||||
|
.find_position(|field| field.0 == "itemsize")
|
||||||
|
.unwrap()
|
||||||
|
.0 as u64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the size of each element `itemsize` into this instance.
|
||||||
|
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &G,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||||
|
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_store(pndims, ndims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the size of each element of this `NDArray` as a value.
|
||||||
|
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
let pndims = self.ptr_to_ndims(ctx);
|
||||||
|
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||||
|
/// `getelementptr` on the field.
|
||||||
|
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let var_name = self.name.map(|v| format!("{v}.shape.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
let field_offset = self
|
||||||
|
.get_type()
|
||||||
|
.get_fields(ctx.ctx, self.llvm_usize)
|
||||||
|
.into_iter()
|
||||||
|
.find_position(|field| field.0 == "shape")
|
||||||
|
.unwrap()
|
||||||
|
.0 as u64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, true)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stores the array of dimension sizes `dims` into this instance.
|
||||||
|
fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||||
|
ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||||
|
pub fn create_dim_sizes(
|
||||||
|
&self,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> {
|
||||||
|
NDArrayDimsProxy(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
|
|
Loading…
Reference in New Issue