diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp index 5b289ac0..396093ea 100644 --- a/nac3core/irrt/irrt_numpy_ndarray.hpp +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -8,9 +8,6 @@ NDArray-related implementations. `*/ -// NDArray indices are always `uint32_t`. -using NDIndex = uint32_t; - namespace { namespace ndarray_util { template @@ -105,14 +102,18 @@ namespace { } struct NDSlice { - // A poor-man's `std::variant` + // A poor-man's enum variant type NDSliceType type; /* if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT` - if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange` + if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange` + + `SizeT` is controlled by the caller: `NDSlice` only cares about where that + slice is (the pointer), `NDSlice` does not care/know about the actual `sizeof()` + of the slice value. */ - uint8_t *slice; + uint8_t* slice; }; namespace ndarray_util { @@ -123,7 +124,7 @@ namespace { SizeT final_ndims = ndims; for (SizeT i = 0; i < num_slices; i++) { if (slices[i].type == INPUT_SLICE_TYPE_INDEX) { - final_ndims--; // An integer slice demotes the rank by 1 + final_ndims--; // An index demotes the rank by 1 } } return final_ndims; @@ -213,8 +214,7 @@ namespace { } void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) { - // *pelement = 0; - // __builtin_memcpy(pelement, pvalue, itemsize); + __builtin_memcpy(pelement, pvalue, itemsize); } uint8_t* get_pelement_by_indices(const SizeT *indices) { @@ -284,7 +284,13 @@ namespace { } } - // To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`) + // To support numpy "basic indexing" https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing + // "Advanced indexing" https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing is not supported + // + // This function supports: + // - "scalar indexing", + // - "slicing and strides", + // - and "dimensional indexing tools" (TODO, but this is really easy to implement). // // Things assumed by this function: // - `dst_ndarray` is allocated by the caller @@ -295,7 +301,7 @@ namespace { // - `dst_ndarray->data` does not have to be set, it will be derived. // - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize` // - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values - void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray* dst_ndarray) { + void subscript(SizeT num_ndslices, NDSlice* ndslices, NDArray* dst_ndarray) { // REFERENCE CODE (check out `_index_helper` in `__getitem__`): // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 @@ -322,15 +328,15 @@ namespace { // Handle when the ndslice is a slice (represented by UserSlice in IRRT) // e.g., `my_array[::2, -5, ::-1]` // ^^^------^^^^----- like these - UserSlice* user_slice = (UserSlice*) ndslice->slice; - Slice slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user + UserSlice* user_slice = (UserSlice*) ndslice->slice; + Slice slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user // NOTE: There is no need to write special code to handle negative steps/strides. // This simple implementation meticulously handles both positive and negative steps/strides. // Check out the tinynumpy and IRRT's test cases if you are not convinced. - dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes) - dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride - dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension + dst_ndarray->data += (SizeT) slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes) + dst_ndarray->strides[dst_axis] = ((SizeT) slice.step) * this->strides[this_axis]; // Determine stride + dst_ndarray->shape[dst_axis] = (SizeT) slice.len(); // Determine shape dimension // Next dst_axis++; @@ -426,7 +432,7 @@ namespace { for (SizeT i = 0; i < size; i++) { uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_nth_pelement(i); uint8_t* this_pelement = this->get_nth_pelement(i); - this->set_pelement_value(src_pelement, src_pelement); + this->set_pelement_value(this_pelement, src_pelement); } } }; @@ -457,7 +463,19 @@ extern "C" { ndarray->fill_generic(pvalue); } - // void __nac3_ndarray_slice(NDArray* ndarray, int32_t num_slices, NDSlice *slices, NDArray *dst_ndarray) { - // // ndarray->slice(num_slices, slices, dst_ndarray); - // } + int32_t __nac3_ndarray_deduce_ndims_after_slicing(int32_t ndims, int32_t num_slices, const NDSlice* slices) { + return ndarray_util::deduce_ndims_after_slicing(ndims, num_slices, slices); + } + + int64_t __nac3_ndarray_deduce_ndims_after_slicing64(int64_t ndims, int64_t num_slices, const NDSlice* slices) { + return ndarray_util::deduce_ndims_after_slicing(ndims, num_slices, slices); + } + + void __nac3_ndarray_subscript(NDArray* ndarray, int32_t num_slices, NDSlice* slices, NDArray *dst_ndarray) { + ndarray->subscript(num_slices, slices, dst_ndarray); + } + + void __nac3_ndarray_subscript64(NDArray* ndarray, int32_t num_slices, NDSlice* slices, NDArray *dst_ndarray) { + ndarray->subscript(num_slices, slices, dst_ndarray); + } } \ No newline at end of file diff --git a/nac3core/irrt/irrt_slice.hpp b/nac3core/irrt/irrt_slice.hpp index 4a565245..ff048618 100644 --- a/nac3core/irrt/irrt_slice.hpp +++ b/nac3core/irrt/irrt_slice.hpp @@ -4,19 +4,15 @@ #include "irrt_typedefs.hpp" namespace { - // A proper slice in IRRT, all negative indices have be resolved to absolute values. - // Even though nac3core's slices are always `int32_t`, we will template slice anyway - // since this struct is used as a general utility. - template struct Slice { - T start; - T stop; - T step; + SliceIndex start; + SliceIndex stop; + SliceIndex step; // The length/The number of elements of the slice if it were a range, // i.e., the value of `len(range(this->start, this->stop, this->end))` - T len() { - T diff = stop - start; + SliceIndex len() { + SliceIndex diff = stop - start; if (diff > 0 && step > 0) { return ((diff - 1) / step) + 1; } else if (diff < 0 && step < 0) { @@ -27,38 +23,45 @@ namespace { } }; - template - T resolve_index_in_length(T length, T index) { + SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) { irrt_assert(length >= 0); if (index < 0) { // Remember that index is negative, so do a plus here - return max(length + index, 0); + return max(length + index, 0); } else { - return min(length, index); + return min(length, index); } } + // A user-written Python-like slice. + // + // i.e., this slice is a triple of either an int or nothing. (e.g., `my_array[:10:2]`, `start` is None) + // + // You can "resolve" a `UserSlice` by using `UserSlice::indices()` + // // NOTE: using a bitfield for the `*_defined` is better, at the // cost of a more annoying implementation in nac3core inkwell - template struct UserSlice { + // Did the user specify `start`? If 0, `start` is undefined (and contains an empty value) uint8_t start_defined; - T start; + SliceIndex start; + // Similar to `start_defined` uint8_t stop_defined; - T stop; + SliceIndex stop; + // Similar to `start_defined` uint8_t step_defined; - T step; + SliceIndex step; // Like Python's `slice(start, stop, step).indices(length)` - Slice indices(T length) { + Slice indices(SliceIndex length) { // NOTE: This function implements Python's `slice.indices` *FAITHFULLY*. // SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546 irrt_assert(length >= 0); irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user - Slice result; + Slice result; result.step = step_defined ? step : 1; bool step_is_negative = result.step < 0; diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index dfe4e9d7..c900ee8e 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -248,10 +248,10 @@ void test_ndarray_set_to_eye() { } void test_slice_1() { - // Test `slice(5, None, None).indices(100) == slice(5, 100, 1)` + // Test `subscript(5, None, None).indices(100) == subscript(5, 100, 1)` BEGIN_TEST(); - UserSlice user_slice = { + UserSlice user_slice = { .start_defined = 1, .start = 5, .stop_defined = 0, @@ -265,10 +265,10 @@ void test_slice_1() { } void test_slice_2() { - // Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)` + // Test `subscript(400, 999, None).indices(100) == subscript(100, 100, 1)` BEGIN_TEST(); - UserSlice user_slice = { + UserSlice user_slice = { .start_defined = 1, .start = 400, .stop_defined = 0, @@ -282,10 +282,10 @@ void test_slice_2() { } void test_slice_3() { - // Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)` + // Test `subscript(-10, -5, None).indices(100) == subscript(90, 95, 1)` BEGIN_TEST(); - UserSlice user_slice = { + UserSlice user_slice = { .start_defined = 1, .start = -10, .stop_defined = 1, @@ -300,10 +300,10 @@ void test_slice_3() { } void test_slice_4() { - // Test `slice(None, None, -5).indices(100) == (99, -1, -5)` + // Test `subscript(None, None, -5).indices(100) == (99, -1, -5)` BEGIN_TEST(); - UserSlice user_slice = { + UserSlice user_slice = { .start_defined = 0, .stop_defined = 0, .step_defined = 1, @@ -366,14 +366,14 @@ void test_ndslice_1() { }; // Create the slice in `ndarray[-2::, 1::2]` - UserSlice user_slice_1 = { + UserSlice user_slice_1 = { .start_defined = 1, .start = -2, .stop_defined = 0, .step_defined = 0 }; - UserSlice user_slice_2 = { + UserSlice user_slice_2 = { .start_defined = 1, .start = 1, .stop_defined = 0, @@ -387,7 +387,7 @@ void test_ndslice_1() { { .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 } }; - ndarray.slice(num_ndslices, ndslices, &dst_ndarray); + ndarray.subscript(num_ndslices, ndslices, &dst_ndarray); int32_t expected_shape[dst_ndims] = { 2, 2 }; int32_t expected_strides[dst_ndims] = { 32, 16 }; @@ -450,7 +450,7 @@ void test_ndslice_2() { // Create the slice in `ndarray[2, ::-2]` int32_t user_slice_1 = 2; - UserSlice user_slice_2 = { + UserSlice user_slice_2 = { .start_defined = 0, .stop_defined = 0, .step_defined = 1, @@ -463,7 +463,7 @@ void test_ndslice_2() { { .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 } }; - ndarray.slice(num_ndslices, ndslices, &dst_ndarray); + ndarray.subscript(num_ndslices, ndslices, &dst_ndarray); int32_t expected_shape[dst_ndims] = { 2 }; int32_t expected_strides[dst_ndims] = { -16 }; diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index a88da1ab..de1fdbf9 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1793,7 +1793,7 @@ pub struct StructFields<'ctx> { pub fields: Vec>, } -struct StructFieldsBuilder<'ctx> { +pub struct StructFieldsBuilder<'ctx> { gep_index_counter: u32, /// Name of the struct to be built. name: &'static str, @@ -1909,11 +1909,11 @@ impl<'ctx> StructFields<'ctx> { } impl<'ctx> StructFieldsBuilder<'ctx> { - fn start(name: &'static str) -> Self { + pub fn start(name: &'static str) -> Self { StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() } } - fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { + pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { let index = self.gep_index_counter; self.gep_index_counter += 1; @@ -1923,11 +1923,12 @@ impl<'ctx> StructFieldsBuilder<'ctx> { field // Return to the caller to conveniently let them do whatever they want } - fn end(self) -> StructFields<'ctx> { + pub fn end(self) -> StructFields<'ctx> { StructFields { name: self.name, fields: self.fields } } } +// TODO: Use derppening's abstraction #[derive(Debug, Clone, Copy)] pub struct NpArrayType<'ctx> { pub size_type: IntType<'ctx>, @@ -1952,15 +1953,15 @@ impl<'ctx> NpArrayType<'ctx> { } pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { - self.fields().whole_struct.get_struct_type(ctx) + self.fields(ctx).whole_struct.get_struct_type(ctx) } - pub fn fields(&self) -> NpArrayStructFields<'ctx> { + pub fn fields(&self, ctx: &'ctx Context) -> NpArrayStructFields<'ctx> { let mut builder = StructFieldsBuilder::start("NpArray"); let addrspace = AddressSpace::default(); - let byte_type = self.size_type.get_context().i8_type(); + let byte_type = ctx.i8_type(); // Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`. let data = builder.add_field("data", byte_type.ptr_type(addrspace).into()); @@ -2018,6 +2019,23 @@ impl<'ctx> NpArrayType<'ctx> { return value; } + + pub fn value_from_ptr( + &self, + ctx: &'ctx Context, + in_ndarray_ptr: PointerValue<'ctx>, + ) -> NpArrayValue<'ctx> { + if cfg!(debug_assertions) { + // Sanity check on `in_ndarray_ptr`'s type + + let in_ndarray_struct_type = + in_ndarray_ptr.get_type().get_element_type().into_struct_type(); + + // unwrap to check + self.fields(ctx).whole_struct.is_type(in_ndarray_struct_type).unwrap(); + } + NpArrayValue { ty: *self, ptr: in_ndarray_ptr } + } } #[derive(Debug, Clone, Copy)] @@ -2028,47 +2046,47 @@ pub struct NpArrayValue<'ctx> { impl<'ctx> NpArrayValue<'ctx> { pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) { - let field = self.ty.fields().data; + let field = self.ty.fields(ctx.ctx).data; field.store(ctx, self.ptr, new_data_ptr); } pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let field = self.ty.fields().ndims; + let field = self.ty.fields(ctx.ctx).ndims; field.load(ctx, self.ptr).into_int_value() } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, new_ndims: IntValue<'ctx>) { - let field = self.ty.fields().ndims; + let field = self.ty.fields(ctx.ctx).ndims; field.store(ctx, self.ptr, new_ndims); } pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let field = self.ty.fields().itemsize; + let field = self.ty.fields(ctx.ctx).itemsize; field.load(ctx, self.ptr).into_int_value() } pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, new_itemsize: IntValue<'ctx>) { - let field = self.ty.fields().itemsize; + let field = self.ty.fields(ctx.ctx).itemsize; field.store(ctx, self.ptr, new_itemsize); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let field = self.ty.fields().shape; + let field = self.ty.fields(ctx.ctx).shape; field.load(ctx, self.ptr).into_pointer_value() } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, new_shape_ptr: PointerValue<'ctx>) { - let field = self.ty.fields().shape; + let field = self.ty.fields(ctx.ctx).shape; field.store(ctx, self.ptr, new_shape_ptr); } pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let field = self.ty.fields().strides; + let field = self.ty.fields(ctx.ctx).strides; field.load(ctx, self.ptr).into_pointer_value() } pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - let field = self.ty.fields().strides; + let field = self.ty.fields(ctx.ctx).strides; field.store(ctx, self.ptr, value); } @@ -2078,7 +2096,7 @@ impl<'ctx> NpArrayValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // Get the pointer to `shape` - let field = self.ty.fields().shape; + let field = self.ty.fields(ctx.ctx).shape; let shape = field.load(ctx, self.ptr).into_pointer_value(); // Load `ndims` @@ -2097,7 +2115,7 @@ impl<'ctx> NpArrayValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // Get the pointer to `strides` - let field = self.ty.fields().strides; + let field = self.ty.fields(ctx.ctx).strides; let strides = field.load(ctx, self.ptr).into_pointer_value(); // Load `ndims` diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4785844f..25493590 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,10 +1,14 @@ -use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; +use std::{ + collections::HashMap, + convert::TryInto, + iter::{once, zip}, +}; use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, - ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayDataProxy, NDArrayValue, + ProxyType, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, @@ -39,10 +43,12 @@ use inkwell::{ }; use itertools::{chain, izip, Either, Itertools}; use nac3parser::ast::{ - self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, - Unaryop, + self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Located, Location, Operator, + StrRef, Unaryop, }; +use super::classes::{NpArrayType, NpArrayValue}; + pub fn get_subst_key( unifier: &mut Unifier, obj: Option, @@ -2094,18 +2100,97 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( /// Generates code for a subscript expression on an `ndarray`. /// /// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. +/// * `ndarray` - The `NDArray` value. /// * `slice` - The slice expression used to subscript into the `ndarray`. fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ty: Type, - ndims: Type, - v: NDArrayValue<'ctx>, + ndarray: NpArrayValue<'ctx>, slice: &Expr>, ) -> Result>, String> { - todo!() + // TODO: bounds check (on IRRT (how?), or using inkwell) + // TODO: For invalid `slice`, throw a proper error + // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools + + let size_type = ndarray.ty.size_type; + debug_assert_eq!(size_type, generator.get_size_type(ctx.ctx)); // The ndarray's size_type somehow isn't that of `generator.get_size_type()`... there would be a bug + + // Annoying notes about `slice` + // - `my_array[5]` + // - slice is a `Constant` + // - `my_array[:5]` + // - slice is a `Slice` + // - `my_array[:]` + // - slice is a `Slice`, but lower upper step would all be `Option::None` + // - `my_array[:, :]` + // - slice is now a `Tuple` of two `Slice`-s + // + // In summary: + // - when there is a comma "," within [], `slice` will be a `Tuple` of the entries. + // - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself. + // + // `entries` will flatten it out. + let entries = match &slice.node { + ExprKind::Tuple { elts, ctx } => elts.iter().collect_vec(), + _ => vec![slice], + }; + + // This could have been written as a `ndslices = entries.into_iter().map(...)`, + // but error shortcutting part would be annoying + let mut ndslices = vec![]; + for entry in entries.into_iter() { + // NOTE: Currently nac3core's slices do not have an object representation, + // so the code/implementation looks awkward - we have to do pattern matching on the expression + let ndslice = match &entry.node { + ExprKind::Slice { lower: start, upper: stop, step } => { + // Helper function here to deduce code duplication + let mut help = |value_expr: &Option< + Box>, Option>>, + >| + -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => Some( + generator + .gen_expr(ctx, &value_expr)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)? + .into_int_value(), + ), + }) + }; + + let start = help(start)?; + let stop = help(stop)?; + let step = help(step)?; + + // NOTE: Now start stop step should all be 32-bit ints after typechecking + // ...and `IrrtUserSlice` expects `int32`s + NDSlice::Slice(UserSlice { start, stop, step }) + } + _ => { + // Anything else that is not a slice (might be illegal values), + // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error + + let index = generator + .gen_expr(ctx, entry)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)? + .into_int_value(); + + NDSlice::Index(index) + } + }; + ndslices.push(ndslice); + } + + // Finally, perform the actual subscript logic + let subndarray = call_nac3_ndarray_subscript_and_alloc_dst(generator, ctx, ndarray, &ndslices.iter().collect_vec()); + + // ...and return the result + let result = ValueEnum::Dynamic(subndarray.ptr.into()); + Ok(Some(result)) // let llvm_i1 = ctx.ctx.bool_type(); // let llvm_i32 = ctx.ctx.i32_type(); @@ -3031,17 +3116,27 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); + let (elem_ty, _) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); - let v = if let Some(v) = generator.gen_expr(ctx, value)? { + // Get the pointer to the ndarray described by `value` + let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? .into_pointer_value() } else { return Ok(None); }; - let v = NDArrayValue::from_ptr_val(v, usize, None); - return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); + // Derive the current NDArray struct type independently... + let ndarray_ty = NpArrayType { + elem_type: ctx.get_llvm_type(generator, *elem_ty), + size_type: generator.get_size_type(ctx.ctx), + }; + + // ...and wrap it in `NDArrayValue` + let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr); + + // Implementation + return gen_ndarray_subscript_expr(generator, ctx, *elem_ty, ndarray, slice); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7a8c44d9..0df3bc50 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,11 +1,18 @@ -use crate::{typecheck::typedef::Type, util::SizeVariant}; +use std::ops::Deref; + +use crate::{ + codegen::classes::{NDArrayType, StructFieldsBuilder}, + typecheck::typedef::Type, + util::SizeVariant, +}; mod test; use super::{ classes::{ check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, - NDArrayValue, NpArrayType, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + NDArrayValue, NpArrayType, NpArrayValue, StructField, StructFields, TypedArrayLikeAdapter, + UntypedArrayLikeAccessor, }, llvm_intrinsics, CodeGenContext, CodeGenerator, }; @@ -16,8 +23,11 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType}, - values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, + types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, StructType}, + values::{ + BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, + PointerValue, + }, AddressSpace, IntPredicate, }; use itertools::Either; @@ -930,6 +940,202 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // ) // } +pub fn get_sliceindex_type<'ctx>(ctx: &'ctx Context) -> IntType<'ctx> { + ctx.i32_type() +} + +pub fn get_ndslicetype_constant_type<'ctx>(ctx: &'ctx Context) -> IntType<'ctx> { + ctx.i8_type() +} + +// TODO: Move to classes.rs? +/// NOTE: All `IntValue<'ctx>` must be `int32_t` +pub struct UserSlice<'ctx> { + pub start: Option>, + pub stop: Option>, + pub step: Option>, +} + +pub struct IrrtUserSliceStructFields<'ctx> { + pub whole_struct: StructFields<'ctx>, + + pub start_defined: StructField<'ctx>, + pub start: StructField<'ctx>, + + pub stop_defined: StructField<'ctx>, + pub stop: StructField<'ctx>, + + pub step_defined: StructField<'ctx>, + pub step: StructField<'ctx>, +} + +// TODO: EMPTY STRUCT +struct IrrtUserSlice {} + +impl IrrtUserSlice { + pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtUserSliceStructFields<'ctx> { + let int8 = ctx.i8_type(); + + // MUST match the corresponding struct defined in IRRT + let mut builder = StructFieldsBuilder::start("NDSlice"); + let start_defined = builder.add_field("start_defined", int8.into()); + let start = builder.add_field("start", get_sliceindex_type(ctx).into()); + let stop_defined = builder.add_field("stop_defined", int8.into()); + let stop = builder.add_field("stop", get_sliceindex_type(ctx).into()); + let step_defined = builder.add_field("step_defined", int8.into()); + let step = builder.add_field("step", get_sliceindex_type(ctx).into()); + + IrrtUserSliceStructFields { + start_defined, + start, + stop_defined, + stop, + step_defined, + step, + whole_struct: builder.end(), + } + } + + pub fn alloca_user_slice<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + user_slice: &UserSlice<'ctx>, + ) -> PointerValue<'ctx> { + // Derive the struct_type + let fields = Self::fields(ctx.ctx); + let struct_type = fields.whole_struct.get_struct_type(ctx.ctx); + + // ...and then allocate for a real `UserSlice` in LLVM + let user_slice_ptr = ctx.builder.build_alloca(struct_type, "user_slice").unwrap(); + + // i8 type to set start_defined, stop_defined, step_defined. + let llvm_i8 = ctx.ctx.i8_type(); + + // Now write to `user_slice_ptr` + let help = |value_defined_field: StructField<'ctx>, + value_field: StructField<'ctx>, + value: Option>| { + match value { + None => { + value_defined_field.store(ctx, user_slice_ptr, llvm_i8.const_zero()); + value_field.store( + ctx, + user_slice_ptr, + get_sliceindex_type(ctx.ctx).const_zero(), + ); + } + Some(value) => { + debug_assert_eq!(get_sliceindex_type(ctx.ctx), value.get_type()); // Sanity check just in case there is a bug somewhere + value_defined_field.store(ctx, user_slice_ptr, llvm_i8.const_int(1, false)); + value_field.store(ctx, user_slice_ptr, value); + } + } + }; + + help(fields.start_defined, fields.start, user_slice.start); + help(fields.stop_defined, fields.stop, user_slice.stop); + help(fields.step_defined, fields.step, user_slice.step); + + user_slice_ptr + } +} + +// TODO: Move to classes.rs? +/// A numpy slice. This corresponds to `NDSlice` defined in IRRT, +/// but with more Rust sugar to help with implementing codegen. +pub enum NDSlice<'ctx> { + /// Index [`IntValue`] must be `int32_t` + Index(IntValue<'ctx>), + Slice(UserSlice<'ctx>), + // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools; *should* be very easy to implement +} + +// TODO: Empty struct +pub struct IrrtNDSlice {} + +pub struct IrrtNDSliceStructFields<'ctx> { + pub whole_struct: StructFields<'ctx>, + pub type_: StructField<'ctx>, + pub slice: StructField<'ctx>, +} + +impl IrrtNDSlice { + pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceStructFields<'ctx> { + let mut builder = StructFieldsBuilder::start("NDSlice"); + + // MUST match the corresponding struct defined in IRRT + let type_ = builder.add_field("type", get_ndslicetype_constant_type(ctx).into()); + let slice = builder.add_field("slice", get_opaque_uint8_ptr_type(ctx).into()); + + IrrtNDSliceStructFields { type_, slice, whole_struct: builder.end() } + } + + pub fn alloca_ndslices<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndslices: &Vec<&NDSlice<'ctx>>, + ) -> PointerValue<'ctx> { + let fields = Self::fields(ctx.ctx); + + // Derive the number of slices + let count = ndslices.len(); + let count_llvm = ctx.ctx.i32_type().const_int(count as u64, false); + + // Allocate `count` number of ndslices, and ready to write + let struct_type = fields.whole_struct.get_struct_type(ctx.ctx); + let ndslices_ptr = + ctx.builder.build_array_alloca(struct_type, count_llvm, "ndslices").unwrap(); + + for (ndslice_i, ndslice) in ndslices.iter().enumerate() { + // Setup the values that will build a real `NDSlice` in LLVM + // NOTE: (A INPUT_SLICE_TYPE_* constant value for `NDSlice::type`, the ptr for `NDArray::slice`) + let (type_, slice_ptr): (u8, PointerValue<'ctx>) = match ndslice { + NDSlice::Index(index) => { + // Sanity check on `index.get_type()` just in case there is a bug somewhere + debug_assert_eq!(get_sliceindex_type(ctx.ctx), index.get_type()); + + // Refer to the IRRT to see what should be set to `uint8_t* slice`. + let slice_ptr = ctx + .builder + .build_alloca(get_sliceindex_type(ctx.ctx), "index_slice") + .unwrap(); + ctx.builder.build_store(slice_ptr, *index).unwrap(); + + let type_ = 0; // const NDSliceType INPUT_SLICE_TYPE_INDEX = 0; + (type_, slice_ptr) + } + NDSlice::Slice(user_slice) => { + // Allocate the user_slice + let slice_ptr = IrrtUserSlice::alloca_user_slice(ctx, user_slice); + + let type_ = 1; // const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; + (type_, slice_ptr) + } + }; + + // Get the pointer to the ndslice_i-th entry of `ndslices_llvm` and write to it. + let gep_index = ctx.ctx.i32_type().const_int(ndslice_i as u64, false); + let ndslice_entry_ptr = unsafe { + ctx.builder + .build_in_bounds_gep(ndslices_ptr, &[gep_index], "ndslice_entry") + .unwrap() + }; + + // Write `type_` + let type_llvm = get_ndslicetype_constant_type(ctx.ctx).const_int(type_ as u64, false); + fields.type_.store(ctx, ndslice_entry_ptr, type_llvm); + + // Write `slice` + // `slice_ptr` has to be casted to `uint8_t*` beforehand + let slice_ptr = ctx + .builder + .build_pointer_cast(slice_ptr, get_opaque_uint8_ptr_type(ctx.ctx), "slices_casted") + .unwrap(); + fields.slice.store(ctx, ndslice_entry_ptr, slice_ptr); + } + + ndslices_ptr + } +} + fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { match ty.get_bit_width() { 32 => SizeVariant::Bits32, @@ -966,18 +1172,18 @@ where } pub fn get_irrt_ndarray_ptr_type<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, + ctx: &'ctx Context, size_type: IntType<'ctx>, ) -> PointerType<'ctx> { - let i8_type = ctx.ctx.i8_type(); + let i8_type = ctx.i8_type(); let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; - let struct_ty = ndarray_ty.get_struct_type(ctx.ctx); + let struct_ty = ndarray_ty.get_struct_type(ctx); struct_ty.ptr_type(AddressSpace::default()) } -pub fn get_irrt_opaque_uint8_ptr_type<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> PointerType<'ctx> { - ctx.ctx.i8_type().ptr_type(AddressSpace::default()) +pub fn get_opaque_uint8_ptr_type<'ctx>(ctx: &'ctx Context) -> PointerType<'ctx> { + ctx.i8_type().ptr_type(AddressSpace::default()) } pub fn call_nac3_ndarray_size<'ctx>( @@ -987,7 +1193,7 @@ pub fn call_nac3_ndarray_size<'ctx>( // Get the IRRT function let size_type = ndarray.ty.size_type; let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || { - size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx, size_type).into()], false) + size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into()], false) }); // Call the IRRT function @@ -1014,8 +1220,8 @@ pub fn call_nac3_ndarray_fill_generic<'ctx>( get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_fill_generic", || { ctx.ctx.void_type().fn_type( &[ - get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray* ndarray - get_irrt_opaque_uint8_ptr_type(ctx).into(), // uint8_t* pvalue + get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray* ndarray + get_opaque_uint8_ptr_type(ctx.ctx).into(), // uint8_t* pvalue ], false, ) @@ -1026,7 +1232,8 @@ pub fn call_nac3_ndarray_fill_generic<'ctx>( ctx.builder.build_store(pvalue, fill_value).unwrap(); // Cast pvalue to `uint8_t*` - let pvalue = ctx.builder.build_pointer_cast(pvalue, get_irrt_opaque_uint8_ptr_type(ctx), "").unwrap(); + let pvalue = + ctx.builder.build_pointer_cast(pvalue, get_opaque_uint8_ptr_type(ctx.ctx), "").unwrap(); // Call the IRRT function ctx.builder @@ -1041,22 +1248,25 @@ pub fn call_nac3_ndarray_fill_generic<'ctx>( .unwrap(); } - pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NpArrayValue<'ctx>, ) { // Get the IRRT function let size_type = ndarray.ty.size_type; - let function = - get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_set_strides_by_shape", || { + let function = get_size_type_dependent_function( + ctx, + size_type, + "__nac3_ndarray_set_strides_by_shape", + || { ctx.ctx.void_type().fn_type( &[ - get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray* ndarray + get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray* ndarray ], false, ) - }); + }, + ); // Call the IRRT function ctx.builder @@ -1069,3 +1279,132 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ) .unwrap(); } + +pub fn call_nac3_ndarray_deduce_ndims_after_slicing_raw<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + size_type: IntType<'ctx>, + ndims: IntValue<'ctx>, + num_slices: IntValue<'ctx>, + ndslices_ptr: PointerValue<'ctx>, +) -> IntValue<'ctx> { + // Get the IRRT function + let function = get_size_type_dependent_function( + ctx, + size_type, + "__nac3_ndarray_deduce_ndims_after_slicing", + || { + size_type.fn_type( + &[ + size_type.into(), // SizeT ndims + size_type.into(), // SizeT num_slices + IrrtNDSlice::fields(ctx.ctx) + .whole_struct + .get_struct_type(ctx.ctx) + .ptr_type(AddressSpace::default()) + .into(), // NDSlice* slices + ], + false, + ) + }, + ); + + // Call the IRRT function + ctx.builder + .build_call( + function, + &[ + ndims.into(), // ndims + num_slices.into(), // num_slices + ndslices_ptr.into(), // slices + ], + "ndims_after_slicing", + ) + .unwrap() + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} + +// TODO: RENAME ME AND MY FRIENDS +pub fn call_nac3_ndarray_subscript_raw<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, + num_slices: IntValue<'ctx>, + slices: PointerValue<'ctx>, + dst_ndarray: NpArrayValue<'ctx>, +) { + // Get the IRRT function + let size_type = ndarray.ty.size_type; + let function = + get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_subscript", || { + ctx.ctx.void_type().fn_type( + &[ + get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray* ndarray + size_type.into(), // SizeT num_slices + IrrtNDSlice::fields(ctx.ctx) + .whole_struct + .get_struct_type(ctx.ctx) + .ptr_type(AddressSpace::default()) + .into(), // NDSlice* slices + get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray* dst_ndarray + ], + false, + ) + }); + + // Call the IRRT function + ctx.builder + .build_call( + function, + &[ + ndarray.ptr.into(), // ndarray + num_slices.into(), // num_slices + slices.into(), // slices + dst_ndarray.ptr.into(), // dst_ndarray + ], + "subndarray", + ) + .unwrap(); +} + +pub fn call_nac3_ndarray_subscript_and_alloc_dst<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, + ndslices: &Vec<&NDSlice<'ctx>>, +) -> NpArrayValue<'ctx> +where + G: CodeGenerator + ?Sized, +{ + // First we will calculate the correct ndims of the dst_ndarray + // Then allocate for dst_ndarray (A known `ndims` value is required for this) + // Finally do call the IRRT function that actually does subscript + + let size_type = ndarray.ty.size_type; + + // Prepare the argument `ndims` + let ndims = ndarray.load_ndims(ctx); + + // Prepare the argument `num_slices` in LLVM - which conveniently is simply `ndslices.len()` + let num_slices = size_type.const_int(ndslices.len() as u64, false); + + // Prepare the argument `slices` + let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); + + // Deduce the ndims + let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw( + ctx, + ndarray.ty.size_type, + ndims, + num_slices, + ndslices_ptr, + ); + + // Allocate `dst_ndarray` + let dst_ndarray = + ndarray.ty.var_alloc(generator, ctx, dst_ndims, Some("subscript_dst_ndarray")); + + call_nac3_ndarray_subscript_raw(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray); + + dst_ndarray +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 03d0f100..c4c80b35 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -37,7 +37,7 @@ use super::{ classes::NpArrayValue, irrt::{ call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, get_irrt_ndarray_ptr_type, - get_irrt_opaque_uint8_ptr_type, + get_opaque_uint8_ptr_type, }, stmt::gen_return, }; @@ -2306,7 +2306,7 @@ where // We also have to cast `data_ptr` to `uint8_t*` because that is what `NDArray` has. let data_ptr = ctx .builder - .build_pointer_cast(data_ptr, get_irrt_opaque_uint8_ptr_type(ctx), "data_casted") + .build_pointer_cast(data_ptr, get_opaque_uint8_ptr_type(ctx.ctx), "data_casted") .unwrap(); ndarray.store_data(ctx, data_ptr); @@ -2410,7 +2410,7 @@ pub fn gen_ndarray_zeros<'ctx>( let ndarray = call_ndarray_fill_impl( generator, context, - float64_ty, + float64_ty, // `elem_ty` is always `float64` shape, shape_ty, float64_llvm_type.const_zero().as_basic_value_enum(), @@ -2442,11 +2442,11 @@ pub fn gen_ndarray_ones<'ctx>( let ndarray = call_ndarray_fill_impl( generator, context, - float64_ty, + float64_ty, // `elem_ty` is always `float64` shape, shape_ty, float64_llvm_type.const_float(1.0).as_basic_value_enum(), - Some("np_zeros.result"), + Some("np_ones.result"), )?; Ok(ndarray.ptr) } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 5d489481..f9e1844b 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -442,11 +442,11 @@ fn test_classes_range_type_new() { // fn test_classes_ndarray_type_new() { // let ctx = inkwell::context::Context::create(); // let generator = DefaultCodeGenerator::new(String::new(), 64); -// +// // let llvm_i32 = ctx.i32_type(); // let llvm_usize = generator.get_size_type(&ctx); -// +// // let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); // assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); // } -// \ No newline at end of file +//