From 9a82b033b6816f8d1de4ae39af762c7b251f3cf8 Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 14 Jul 2024 15:45:06 +0800 Subject: [PATCH] core: ndarray fill generic --- nac3core/irrt/irrt/ndarray/ndarray.hpp | 21 +++++ nac3core/src/codegen/irrt/classes.rs | 16 +--- nac3core/src/codegen/irrt/numpy.rs | 114 ++++++++++++++----------- nac3core/src/codegen/optics.rs | 9 +- 4 files changed, 94 insertions(+), 66 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray/ndarray.hpp b/nac3core/irrt/irrt/ndarray/ndarray.hpp index 29c2e739..c0dcc8a6 100644 --- a/nac3core/irrt/irrt/ndarray/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray/ndarray.hpp @@ -96,6 +96,19 @@ struct NDArray { } return get_nth_pelement(nth); } + + void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) { + __builtin_memcpy(pelement, pvalue, itemsize); + } + + // Fill the ndarray with a value + void fill_generic(const uint8_t* pvalue) { + const SizeT size = this->size(); + for (SizeT i = 0; i < size; i++) { + uint8_t* pelement = get_nth_pelement(i); // No need for checked_get_nth_pelement + set_pelement_value(pelement, pvalue); + } + } }; } @@ -131,4 +144,12 @@ void __nac3_ndarray_set_strides_by_shape(NDArray* ndarray) { void __nac3_ndarray_set_strides_by_shape64(NDArray* ndarray) { ndarray->set_strides_by_shape(); } + +void __nac3_ndarray_fill_generic(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); +} + +void __nac3_ndarray_fill_generic64(NDArray* ndarray, uint8_t* pvalue) { + ndarray->fill_generic(pvalue); +} } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/classes.rs b/nac3core/src/codegen/irrt/classes.rs index fe2f5f39..0c28c745 100644 --- a/nac3core/src/codegen/irrt/classes.rs +++ b/nac3core/src/codegen/irrt/classes.rs @@ -1,9 +1,7 @@ use inkwell::types::IntType; use crate::codegen::{ - optics::{ - Address, AddressLens, ArraySlice, FieldBuilder, GepGetter, IntLens, Optic, StructureOptic, - }, + optics::{Address, AddressLens, ArraySlice, FieldBuilder, GepGetter, IntLens, StructureOptic}, CodeGenContext, }; @@ -41,18 +39,12 @@ pub struct NpArrayFields<'ctx> { pub strides: GepGetter>>, } -// Note: NpArrayLens's ElementOptic is purely for type-safety and type-guidances -// The underlying LLVM ndarray doesn't care, it only holds an opaque (uint8_t*) pointer to the elements. #[derive(Debug, Clone, Copy)] -pub struct NpArrayLens<'ctx, ElementOptic> { +pub struct NpArrayLens<'ctx> { pub size_type: IntType<'ctx>, - pub element_optic: ElementOptic, } -// NDArray is *frequently* used, so here is a type alias -pub type NpArray<'ctx, ElementOptic> = Address<'ctx, NpArrayLens<'ctx, ElementOptic>>; - -impl<'ctx, ElementOptic: Optic<'ctx>> StructureOptic<'ctx> for NpArrayLens<'ctx, ElementOptic> { +impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { type Fields = NpArrayFields<'ctx>; fn struct_name(&self) -> &'static str { @@ -74,7 +66,7 @@ impl<'ctx, ElementOptic: Optic<'ctx>> StructureOptic<'ctx> for NpArrayLens<'ctx, } // Other convenient utilities for NpArray -impl<'ctx, ElementOptic: Optic<'ctx>> NpArray<'ctx, ElementOptic> { +impl<'ctx> Address<'ctx, NpArrayLens<'ctx>> { pub fn shape_array(&self, ctx: &CodeGenContext<'ctx, '_>) -> ArraySlice<'ctx, IntLens<'ctx>> { let ndims = self.focus(ctx, |fields| &fields.ndims).load(ctx, "ndims"); let shape_base_ptr = self.focus(ctx, |fields| &fields.shape).load(ctx, "shape"); diff --git a/nac3core/src/codegen/irrt/numpy.rs b/nac3core/src/codegen/irrt/numpy.rs index c81ba9da..ea576136 100644 --- a/nac3core/src/codegen/irrt/numpy.rs +++ b/nac3core/src/codegen/irrt/numpy.rs @@ -1,14 +1,14 @@ use std::marker::PhantomData; use inkwell::{ - types::BasicType, + types::{BasicType, BasicTypeEnum}, values::{BasicValueEnum, IntValue}, }; use crate::{ codegen::{ classes::{ListValue, UntypedArrayLikeAccessor}, - optics::{Address, AddressLens, ArraySlice, IntLens, Ixed, Optic}, + optics::{opaque_address_lens, Address, AddressLens, ArraySlice, IntLens, Ixed, Optic}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }, @@ -16,7 +16,7 @@ use crate::{ }; use super::{ - classes::{ErrorContextLens, NpArray, NpArrayLens}, + classes::{ErrorContextLens, NpArrayLens}, new::{ check_error_context, get_sized_dependent_function_name, prepare_error_context, FunctionBuilder, @@ -176,38 +176,43 @@ where } } -pub fn alloca_ndarray<'ctx, G, ElementOptic: Optic<'ctx>>( +pub fn alloca_ndarray<'ctx, G>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - element_optic: ElementOptic, + elem_type: BasicTypeEnum<'ctx>, ndims: IntValue<'ctx>, name: &str, -) -> Result, String> +) -> Result>, String> where G: CodeGenerator + ?Sized, { let size_type = generator.get_size_type(ctx.ctx); - let itemsize = element_optic.get_llvm_type(ctx.ctx).size_of().unwrap(); + // Allocate ndarray + let ndarray_ptr = NpArrayLens { size_type }.alloca(ctx, name); + + // Set ndims + ndarray_ptr.focus(ctx, |fields| &fields.ndims).store(ctx, &ndims); + + // Set itemsize + let itemsize = elem_type.size_of().unwrap(); let itemsize = ctx.builder.build_int_s_extend_or_bit_cast(itemsize, size_type, "itemsize").unwrap(); + ndarray_ptr.focus(ctx, |fields| &fields.itemsize).store(ctx, &itemsize); - let shape = ctx.builder.build_array_alloca(size_type, ndims, "shape").unwrap(); - let strides = ctx.builder.build_array_alloca(size_type, ndims, "strides").unwrap(); - - let ndarray = NpArrayLens { size_type, element_optic }.alloca(ctx, name); - - // Set ndims, itemsize; and allocate shape and store on the stack - ndarray.focus(ctx, |fields| &fields.ndims).store(ctx, &ndims); - ndarray.focus(ctx, |fields| &fields.itemsize).store(ctx, &itemsize); - ndarray + // Allocate and set shape + let shape_ptr = ctx.builder.build_array_alloca(size_type, ndims, "shape").unwrap(); + ndarray_ptr .focus(ctx, |fields| &fields.shape) - .store(ctx, &Address { addressee_optic: IntLens(size_type), address: shape }); - ndarray - .focus(ctx, |fields| &fields.strides) - .store(ctx, &Address { addressee_optic: IntLens(size_type), address: strides }); + .store(ctx, &Address { addressee_optic: IntLens(size_type), address: shape_ptr }); - Ok(ndarray) + // Allocate and set strides + let strides_ptr = ctx.builder.build_array_alloca(size_type, ndims, "strides").unwrap(); + ndarray_ptr + .focus(ctx, |fields| &fields.strides) + .store(ctx, &Address { addressee_optic: IntLens(size_type), address: strides_ptr }); + + Ok(ndarray_ptr) } enum NDArrayInitMode<'ctx, G: CodeGenerator + ?Sized> { @@ -217,75 +222,75 @@ enum NDArrayInitMode<'ctx, G: CodeGenerator + ?Sized> { } /// TODO: DOCUMENT ME -fn alloca_ndarray_and_init<'ctx, G, ElementOptic: Optic<'ctx>>( +fn alloca_ndarray_and_init<'ctx, G>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - element_optic: ElementOptic, + elem_type: BasicTypeEnum<'ctx>, init_mode: NDArrayInitMode<'ctx, G>, name: &str, -) -> Result, String> +) -> Result>, String> where G: CodeGenerator + ?Sized, { // It is implemented verbosely in order to make the initialization modes super clear in their intent. match init_mode { NDArrayInitMode::NDim { ndim: ndims, _phantom } => { - let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?; Ok(ndarray) } NDArrayInitMode::Shape { shape } => { let ndims = shape.count; - let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?; // Fill `ndarray.shape` - (shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?; + (shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_array(ctx))?; // Check if `shape` has bad inputs call_nac3_ndarray_util_assert_shape_no_negative( generator, ctx, ndims, - &ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), + &ndarray_ptr.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), ); // NOTE: DO NOT DO `set_strides_by_shape` HERE. // Simply this is because we specified that `SetShape` wouldn't do `set_strides_by_shape` - Ok(ndarray) + Ok(ndarray_ptr) } NDArrayInitMode::ShapeAndAllocaData { shape } => { let ndims = shape.count; - let ndarray = alloca_ndarray(generator, ctx, element_optic, ndims, name)?; + let ndarray_ptr = alloca_ndarray(generator, ctx, elem_type, ndims, name)?; // Fill `ndarray.shape` - (shape.write_to_array)(generator, ctx, &ndarray.shape_array(ctx))?; + (shape.write_to_array)(generator, ctx, &ndarray_ptr.shape_array(ctx))?; // Check if `shape` has bad inputs call_nac3_ndarray_util_assert_shape_no_negative( generator, ctx, ndims, - &ndarray.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), + &ndarray_ptr.focus(ctx, |fields| &fields.shape).load(ctx, "shape"), ); // Now we populate `ndarray.data` by alloca-ing. // But first, we need to know the size of the ndarray to know how many elements to alloca, // since calculating nbytes of an ndarray requires `ndarray.shape` to be set. - let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray); + let ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, &ndarray_ptr); // Alloca `data` and assign it to `ndarray.data` let data_ptr = ctx.builder.build_array_alloca(ctx.ctx.i8_type(), ndarray_nbytes, "data").unwrap(); - ndarray.focus(ctx, |fields| &fields.data).store( + ndarray_ptr.focus(ctx, |fields| &fields.data).store( ctx, &Address { addressee_optic: IntLens::int8(ctx.ctx), address: data_ptr }, ); // Finally, do `set_strides_by_shape` // Check out https://ajcr.net/stride-guide-part-1/ to see what numpy "strides" are. - call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray); + call_nac3_ndarray_set_strides_by_shape(generator, ctx, &ndarray_ptr); - Ok(ndarray) + Ok(ndarray_ptr) } } } @@ -294,7 +299,7 @@ fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Siz generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>, - shape: &Address<'ctx, IntLens<'ctx>>, + shape_ptr: &Address<'ctx, IntLens<'ctx>>, ) { let size_type = generator.get_size_type(ctx.ctx); @@ -308,19 +313,15 @@ fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Siz ) .arg("errctx", &AddressLens(ErrorContextLens), &errctx) .arg("ndims", &IntLens(size_type), &ndims) - .arg("shape", &AddressLens(IntLens(size_type)), shape) + .arg("shape", &AddressLens(IntLens(size_type)), shape_ptr) .returning_void(); check_error_context(generator, ctx, &errctx); } -fn call_nac3_ndarray_set_strides_by_shape< - 'ctx, - G: CodeGenerator + ?Sized, - ElementOptic: Optic<'ctx>, ->( +fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: &NpArray<'ctx, ElementOptic>, + ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>, ) { let size_type = generator.get_size_type(ctx.ctx); @@ -331,14 +332,14 @@ fn call_nac3_ndarray_set_strides_by_shape< "__nac3_ndarray_util_assert_shape_no_negative", ), ) - .arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray) + .arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr) .returning_void(); } -fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized, ElementOptic: Optic<'ctx>>( +fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: &NpArray<'ctx, ElementOptic>, + ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>, ) -> IntValue<'ctx> { let size_type = generator.get_size_type(ctx.ctx); @@ -349,6 +350,23 @@ fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized, ElementOptic: Optic "__nac3_ndarray_util_assert_shape_no_negative", ), ) - .arg("ndarray", &AddressLens(ndarray.addressee_optic.clone()), ndarray) + .arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr) .returning("nbytes", &IntLens(size_type)) } + +fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray_ptr: &Address<'ctx, NpArrayLens<'ctx>>, + fill_value_ptr: &Address<'ctx, IntLens<'ctx>>, +) { + let size_type = generator.get_size_type(ctx.ctx); + + FunctionBuilder::begin( + ctx, + &get_sized_dependent_function_name(size_type, "__nac3_ndarray_fill_generic"), + ) + .arg("ndarray", &AddressLens(NpArrayLens { size_type }), ndarray_ptr) + .arg("pvalue", &opaque_address_lens(ctx.ctx), fill_value_ptr) + .returning_void(); +} diff --git a/nac3core/src/codegen/optics.rs b/nac3core/src/codegen/optics.rs index ddc021a1..97824b5c 100644 --- a/nac3core/src/codegen/optics.rs +++ b/nac3core/src/codegen/optics.rs @@ -50,8 +50,6 @@ pub trait MemorySetter<'ctx>: Optic<'ctx> { fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, value: &Self::Value); } -pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {} - // NOTE: I wanted to make Int8Lens, Int16Lens, Int32Lens, with all // having the trait IsIntLens, and implement `impl Optic for T`, // but that clashes with StructureOptic!! @@ -141,10 +139,9 @@ impl<'ctx, AddresseeOptic> OpticValue<'ctx> for Address<'ctx, AddresseeOptic> { #[derive(Debug, Clone)] pub struct AddressLens(pub AddresseeOptic); -impl AddressLens { - pub fn new_opaque<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> AddressLens> { - AddressLens(IntLens::int8(ctx.ctx)) - } +#[must_use] +pub fn opaque_address_lens(ctx: &Context) -> AddressLens> { + AddressLens(IntLens::int8(ctx)) } impl<'ctx, AddresseeOptic: Optic<'ctx>> Optic<'ctx> for AddressLens {