From 8f9d2d82dda593a216be1f6c5ce5b2db681a1068 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 13:43:07 +0800 Subject: [PATCH] core/ndstrides: implement ndarray indexing The functionality for `...` and `np.newaxis` is there in IRRT, but there is no implementation of them for @kernel Python expressions because of https://git.m-labs.hk/M-Labs/nac3/issues/486. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/indexing.hpp | 220 +++++++++++ nac3core/src/codegen/expr.rs | 370 ++---------------- nac3core/src/codegen/irrt/mod.rs | 19 +- nac3core/src/codegen/object/mod.rs | 1 + .../src/codegen/object/ndarray/indexing.rs | 226 +++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 45 ++- nac3core/src/codegen/object/utils/mod.rs | 1 + nac3core/src/codegen/object/utils/slice.rs | 125 ++++++ 9 files changed, 660 insertions(+), 350 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/indexing.hpp create mode 100644 nac3core/src/codegen/object/ndarray/indexing.rs create mode 100644 nac3core/src/codegen/object/utils/mod.rs create mode 100644 nac3core/src/codegen/object/utils/slice.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 43f15e8..b586ce5 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -7,4 +7,5 @@ #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" -#include "irrt/ndarray/iter.hpp" \ No newline at end of file +#include "irrt/ndarray/iter.hpp" +#include "irrt/ndarray/indexing.hpp" diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp new file mode 100644 index 0000000..b2597d0 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/range.hpp" +#include "irrt/slice.hpp" + +namespace { +typedef uint8_t NDIndexType; + +/** + * @brief A single element index + * + * `data` points to a `int32_t`. + */ +const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; + +/** + * @brief A slice index + * + * `data` points to a `Slice`. + */ +const NDIndexType ND_INDEX_TYPE_SLICE = 1; + +/** + * @brief `np.newaxis` / `None` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2; + +/** + * @brief `Ellipsis` / `...` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3; + +/** + * @brief An index used in ndarray indexing + * + * That is: + * ``` + * my_ndarray[::-1, 3, ..., np.newaxis] + * ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex. + * ``` + */ +struct NDIndex { + /** + * @brief Enum tag to specify the type of index. + * + * Please see the comment of each enum constant. + */ + NDIndexType type; + + /** + * @brief The accompanying data associated with `type`. + * + * Please see the comment of each enum constant. + */ + uint8_t* data; +}; +} // namespace + +namespace { +namespace ndarray { +namespace indexing { +/** + * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) + * + * This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python. + * + * This function also does proper assertions on `indices` to check for out of bounds access and more. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after + * indexing `src_ndarray` with `indices`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data`. + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`. + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed. + * - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works. + * + * @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python. + * @param src_ndarray The NDArray to be indexed. + * @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above, + */ +template +void index(SizeT num_indices, const NDIndex* indices, const NDArray* src_ndarray, NDArray* dst_ndarray) { + // Validate `indices`. + + // Expected value of `dst_ndarray->ndims`. + SizeT expected_dst_ndims = src_ndarray->ndims; + // To check for "too many indices for array: array is ?-dimensional, but ? were indexed" + SizeT num_indexed = 0; + // There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis. + SizeT num_ellipsis = 0; + + for (SizeT i = 0; i < num_indices; i++) { + if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + expected_dst_ndims--; + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_SLICE) { + num_indexed++; + } else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) { + expected_dst_ndims++; + } else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) { + num_ellipsis++; + if (num_ellipsis > 1) { + raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM, + NO_PARAM, NO_PARAM); + } + } else { + __builtin_unreachable(); + } + } + + debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims); + + if (src_ndarray->ndims - num_indexed < 0) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + src_ndarray->ndims, num_indices, NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Reference code: + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + SizeT src_axis = 0; + SizeT dst_axis = 0; + + for (int32_t i = 0; i < num_indices; i++) { + const NDIndex* index = &indices[i]; + if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) { + SizeT input = (SizeT) * ((int32_t*)index->data); + + SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); + if (k == -1) { + raise_exception(SizeT, EXN_INDEX_ERROR, + "index {0} is out of bounds for axis {1} " + "with size {2}", + input, src_axis, src_ndarray->shape[src_axis]); + } + + dst_ndarray->data += k * src_ndarray->strides[src_axis]; + + src_axis++; + } else if (index->type == ND_INDEX_TYPE_SLICE) { + Slice* slice = (Slice*)index->data; + + Range range = slice->indices_checked(src_ndarray->shape[src_axis]); + + dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = (SizeT)range.len(); + + dst_axis++; + src_axis++; + } else if (index->type == ND_INDEX_TYPE_NEWAXIS) { + dst_ndarray->strides[dst_axis] = 0; + dst_ndarray->shape[dst_axis] = 1; + + dst_axis++; + } else if (index->type == ND_INDEX_TYPE_ELLIPSIS) { + // The number of ':' entries this '...' implies. + SizeT ellipsis_size = src_ndarray->ndims - num_indexed; + + for (SizeT j = 0; j < ellipsis_size; j++) { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + + dst_axis++; + src_axis++; + } + } else { + __builtin_unreachable(); + } + } + + for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) { + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + + debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); + debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); +} +} // namespace indexing +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::indexing; + +void __nac3_ndarray_index(int32_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_index64(int64_t num_indices, + NDIndex* indices, + NDArray* src_ndarray, + NDArray* dst_ndarray) { + index(num_indices, indices, src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8a01042..84a4c3f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -21,7 +21,7 @@ use nac3parser::ast::{ use super::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue, - RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + RangeValue, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, @@ -32,6 +32,10 @@ use super::{ }, macros::codegen_unreachable, need_sret, numpy, + object::{ + any::AnyObject, + ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject}, + }, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -40,11 +44,7 @@ use super::{ }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -2505,338 +2505,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { - codegen_unreachable!(ctx) - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( - ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), - None, - ); - let ndarray_ty = - make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? - else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? { - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let subscripted_ndarray = - generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); - - let num_dims = v.load_ndims(ctx); - ndarray.store_ndims( - ctx, - generator, - ctx.builder - .build_int_sub(num_dims, llvm_usize.const_int(1, false), "") - .unwrap(), - ); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, @@ -3493,18 +3161,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.data().get(ctx, generator, &index, None).into() } } - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); - - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_pointer_value() - } else { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let Some(ndarray) = generator.gen_expr(ctx, value)? else { return Ok(None); }; - let v = NDArrayValue::from_ptr_val(v, usize, None); - return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); + let ndarray_ty = value.custom.unwrap(); + let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayObject::from_object( + generator, + ctx, + AnyObject { ty: ndarray_ty, value: ndarray }, + ); + + let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?; + let result = ndarray + .index(generator, ctx, &indices) + .split_unsized(generator, ctx) + .to_basic_value_enum(); + return Ok(Some(ValueEnum::Dynamic(result))); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7f2cd68..94996cd 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -19,7 +19,7 @@ use super::{ llvm_intrinsics, macros::codegen_unreachable, model::{function::FnCall, *}, - object::ndarray::{nditer::NDIter, NDArray}, + object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -1112,3 +1112,20 @@ pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); FnCall::builder(generator, ctx, &name).arg(iter).returning_void(); } + +pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_indices: Instance<'ctx, Int>, + indices: Instance<'ctx, Ptr>>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + FnCall::builder(generator, ctx, &name) + .arg(num_indices) + .arg(indices) + .arg(src_ndarray) + .arg(dst_ndarray) + .returning_void(); +} diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs index 9466d70..17b0b94 100644 --- a/nac3core/src/codegen/object/mod.rs +++ b/nac3core/src/codegen/object/mod.rs @@ -2,3 +2,4 @@ pub mod any; pub mod list; pub mod ndarray; pub mod tuple; +pub mod utils; diff --git a/nac3core/src/codegen/object/ndarray/indexing.rs b/nac3core/src/codegen/object/ndarray/indexing.rs new file mode 100644 index 0000000..d4fbbb3 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/indexing.rs @@ -0,0 +1,226 @@ +use super::NDArrayObject; +use crate::codegen::{ + irrt::call_nac3_ndarray_index, + model::*, + object::utils::slice::{RustSlice, Slice}, + CodeGenContext, CodeGenerator, +}; + +pub type NDIndexType = Byte; + +/// Fields of [`NDIndex`] +#[derive(Debug, Clone, Copy)] +pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> { + pub type_: F::Output>, + pub data: F::Output>>, +} + +/// An IRRT representation of an ndarray subscript index. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct NDIndex; + +impl<'ctx> StructKind<'ctx> for NDIndex { + type Fields> = NDIndexFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") } + } +} + +// A convenience enum representing a [`NDIndex`]. +#[derive(Debug, Clone)] +pub enum RustNDIndex<'ctx> { + SingleElement(Instance<'ctx, Int>), + Slice(RustSlice<'ctx, Int32>), + NewAxis, + Ellipsis, +} + +impl<'ctx> RustNDIndex<'ctx> { + /// Get the value to set `NDIndex::type` for this variant. + fn get_type_id(&self) -> u64 { + // Defined in IRRT, must be in sync + match self { + RustNDIndex::SingleElement(_) => 0, + RustNDIndex::Slice(_) => 1, + RustNDIndex::NewAxis => 2, + RustNDIndex::Ellipsis => 3, + } + } + + /// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndex`]. + fn write_to_ndindex( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_ndindex_ptr: Instance<'ctx, Ptr>>, + ) { + // Set `dst_ndindex_ptr->type` + dst_ndindex_ptr.gep(ctx, |f| f.type_).store( + ctx, + Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id(), false), + ); + + // Set `dst_ndindex_ptr->data` + match self { + RustNDIndex::SingleElement(in_index) => { + let index_ptr = Int(Int32).alloca(generator, ctx); + index_ptr.store(ctx, *in_index); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::Slice(in_rust_slice) => { + let user_slice_ptr = Struct(Slice(Int32)).alloca(generator, ctx); + in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {} + } + } + + /// Serialize a list of `RustNDIndex` as a newly allocated LLVM array of `NDIndex`. + pub fn make_ndindices( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + in_ndindices: &[RustNDIndex<'ctx>], + ) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let ndindex_model = Struct(NDIndex); + + // Allocate the LLVM ndindices. + let num_ndindices = + Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64, false); + let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value); + + // Initialize all of them. + for (i, in_ndindex) in in_ndindices.iter().enumerate() { + let pndindex = ndindices.offset_const(ctx, i64::try_from(i).unwrap()); + in_ndindex.write_to_ndindex(generator, ctx, pndindex); + } + + (num_ndindices, ndindices) + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Get the expected `ndims` after indexing with `indices`. + #[must_use] + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; + for index in indices { + match index { + RustNDIndex::SingleElement(_) => { + ndims -= 1; // Single elements decrements ndims + } + RustNDIndex::NewAxis => { + ndims += 1; // `np.newaxis` / `none` adds a new axis + } + RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {} + } + } + ndims + } + + /// Index into the ndarray, and return a newly-allocated view on this ndarray. + /// + /// This function behaves like NumPy's ndarray indexing, but if the indices index + /// into a single element, an unsized ndarray is returned. + #[must_use] + pub fn index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: &[RustNDIndex<'ctx>], + ) -> Self { + let dst_ndims = self.deduce_ndims_after_indexing_with(indices); + let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims); + + let (num_indices, indices) = RustNDIndex::make_ndindices(generator, ctx, indices); + call_nac3_ndarray_index( + generator, + ctx, + num_indices, + indices, + self.instance, + dst_ndarray.instance, + ); + + dst_ndarray + } +} + +pub mod util { + use itertools::Itertools; + use nac3parser::ast::{Expr, ExprKind}; + + use crate::{ + codegen::{model::*, object::utils::slice::util::gen_slice, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, + }; + + use super::RustNDIndex; + + /// Generate LLVM code to transform an ndarray subscript expression to + /// its list of [`RustNDIndex`] + /// + /// i.e., + /// ```python + /// my_ndarray[::3, 1, :2:] + /// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es + /// ``` + pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + subscript: &Expr>, + ) -> Result>, String> { + // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools + + // Annoying notes about `slice` + // - `my_array[5]` + // - slice is a `Constant` + // - `my_array[:5]` + // - slice is a `Slice` + // - `my_array[:]` + // - slice is a `Slice`, but lower upper step would all be `Option::None` + // - `my_array[:, :]` + // - slice is now a `Tuple` of two `Slice`-s + // + // In summary: + // - when there is a comma "," within [], `slice` will be a `Tuple` of the entries. + // - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself. + // + // So we first "flatten" out the slice expression + let index_exprs = match &subscript.node { + ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(), + _ => vec![subscript], + }; + + // Process all index expressions + let mut rust_ndindices: Vec = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here. + for index_expr in index_exprs { + // NOTE: Currently nac3core's slices do not have an object representation, + // so the code/implementation looks awkward - we have to do pattern matching on the expression + let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { + // Handle slices + let slice = gen_slice(generator, ctx, lower, upper, step)?; + RustNDIndex::Slice(slice) + } else { + // Treat and handle everything else as a single element index. + let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( + ctx, + generator, + ctx.primitives.int32, // Must be int32, this checks for illegal values + )?; + let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap(); + + RustNDIndex::SingleElement(index) + }; + rust_ndindices.push(ndindex); + } + Ok(rust_ndindices) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 8b55c2e..7fa3365 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, types::BasicType, - values::{BasicValueEnum, PointerValue}, + values::{BasicValue, BasicValueEnum, PointerValue}, AddressSpace, }; @@ -22,6 +22,7 @@ use crate::{ }; pub mod factory; +pub mod indexing; pub mod nditer; pub mod shape_util; @@ -352,6 +353,30 @@ impl<'ctx> NDArrayObject<'ctx> { call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. + #[must_use] + pub fn is_unsized(&self) -> bool { + self.ndims == 0 + } + + /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + if self.is_unsized() { + // NOTE: `np.size(self) == 0` here is never possible. + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let value = self.get_nth_scalar(generator, ctx, zero).value; + + ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value }) + } else { + ScalarOrNDArray::NDArray(*self) + } + } + /// Fill the ndarray with a scalar. /// /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. @@ -369,3 +394,21 @@ impl<'ctx> NDArrayObject<'ctx> { .unwrap(); } } + +/// A convenience enum for implementing functions that acts on scalars or ndarrays or both. +#[derive(Debug, Clone, Copy)] +pub enum ScalarOrNDArray<'ctx> { + Scalar(AnyObject<'ctx>), + NDArray(NDArrayObject<'ctx>), +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. + #[must_use] + pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { + match self { + ScalarOrNDArray::Scalar(scalar) => scalar.value, + ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), + } + } +} diff --git a/nac3core/src/codegen/object/utils/mod.rs b/nac3core/src/codegen/object/utils/mod.rs new file mode 100644 index 0000000..913812d --- /dev/null +++ b/nac3core/src/codegen/object/utils/mod.rs @@ -0,0 +1 @@ +pub mod slice; diff --git a/nac3core/src/codegen/object/utils/slice.rs b/nac3core/src/codegen/object/utils/slice.rs new file mode 100644 index 0000000..f4bb3f5 --- /dev/null +++ b/nac3core/src/codegen/object/utils/slice.rs @@ -0,0 +1,125 @@ +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +/// Fields of [`Slice`] +#[derive(Debug, Clone)] +pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> { + pub start_defined: F::Output>, + pub start: F::Output>, + pub stop_defined: F::Output>, + pub stop: F::Output>, + pub step_defined: F::Output>, + pub step: F::Output>, +} + +/// An IRRT representation of an (unresolved) slice. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct Slice(pub N); + +impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice { + type Fields> = SliceFields<'ctx, F, N>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + start_defined: traversal.add_auto("start_defined"), + start: traversal.add("start", Int(self.0)), + stop_defined: traversal.add_auto("stop_defined"), + stop: traversal.add("stop", Int(self.0)), + step_defined: traversal.add_auto("step_defined"), + step: traversal.add("step", Int(self.0)), + } + } +} + +/// A Rust structure that has [`Slice`] utilities and looks like a [`Slice`] but +/// `start`, `stop` and `step` are held by LLVM registers only and possibly +/// [`Option::None`] if unspecified. +#[derive(Debug, Clone)] +pub struct RustSlice<'ctx, N: IntKind<'ctx>> { + // It is possible that `start`, `stop`, and `step` are all `None`. + // We need to know the `int_kind` even when that is the case. + pub int_kind: N, + pub start: Option>>, + pub stop: Option>>, + pub step: Option>>, +} + +impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> { + /// Write the contents to an LLVM [`Slice`]. + pub fn write_to_slice( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_slice_ptr: Instance<'ctx, Ptr>>>, + ) { + let false_ = Int(Bool).const_false(generator, ctx.ctx); + let true_ = Int(Bool).const_true(generator, ctx.ctx); + + match self.start { + Some(start) => { + dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start); + } + None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_), + } + + match self.stop { + Some(stop) => { + dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop); + } + None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_), + } + + match self.step { + Some(step) => { + dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step); + } + None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_), + } + } +} + +pub mod util { + use nac3parser::ast::Expr; + + use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, + }; + + use super::RustSlice; + + /// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`]. + #[allow(clippy::type_complexity)] + pub fn gen_slice<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lower: &Option>>>, + upper: &Option>>>, + step: &Option>>>, + ) -> Result, String> { + let mut help = |value_expr: &Option>>>| -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => { + let value_expr = generator + .gen_expr(ctx, value_expr)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + + let value_expr = + Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap(); + + Some(value_expr) + } + }) + }; + + let start = help(lower)?; + let stop = help(upper)?; + let step = help(step)?; + + Ok(RustSlice { int_kind: Int32, start, stop, step }) + } +}