[core] codegen: Implement construction of unsized ndarrays

Partially based on f731e604: core/ndstrides: add more ScalarOrNDArray
and NDArrayObject utils.
This commit is contained in:
David Mak 2024-12-12 11:19:01 +08:00
parent 061747c67b
commit 27a6f47330

View File

@ -1,7 +1,7 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
values::{BasicValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
@ -116,6 +116,19 @@ impl<'ctx> NDArrayType<'ctx> {
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
}
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
#[must_use]
pub fn new_unsized<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize }
}
/// Creates an [`NDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
@ -343,6 +356,34 @@ impl<'ctx> NDArrayType<'ctx> {
ndarray
}
/// Create an unsized ndarray to contain `value`.
#[must_use]
pub fn construct_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: &impl BasicValue<'ctx>,
name: Option<&'ctx str>,
) -> NDArrayValue<'ctx> {
let value = value.as_basic_value_enum();
assert_eq!(value.get_type(), self.dtype);
assert!(self.ndims.is_none_or(|ndims| ndims == 0));
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap();
ctx.builder.build_store(data, value).unwrap();
let data = ctx
.builder
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type())
.construct_uninitialized(generator, ctx, name);
ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap();
ndarray
}
/// Converts an existing value into a [`NDArrayValue`].
#[must_use]
pub fn map_value(