forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement ndarray indexing

The functionality for `...` and `np.newaxis` is there in IRRT, but there
is no implementation of them for @kernel Python expressions because of
M-Labs/nac3#486.
This commit is contained in:
lyken 2024-08-21 13:43:07 +08:00
parent 494616f0d9
commit 7e22f09e6e
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
6 changed files with 665 additions and 349 deletions

View File

@ -3,6 +3,7 @@
#include <irrt/math_util.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/original.hpp>
#include <irrt/range.hpp>

View File

@ -0,0 +1,243 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_types.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/range.hpp>
#include <irrt/slice.hpp>
namespace
{
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `int32_t`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `Slice<int32_t>`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief `np.newaxis` / `None`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
/**
* @brief `Ellipsis` / `...`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
/**
* @brief An index used in ndarray indexing
*/
struct NDIndex
{
/**
* @brief Enum tag to specify the type of index.
*
* Please see comments of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see comments of each enum constant.
*/
uint8_t *data;
};
} // namespace
namespace
{
namespace ndarray
{
namespace indexing
{
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
*
* This function also does proper assertions on `indices` to check for out of bounds access.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indices`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template <typename SizeT>
void index(SizeT num_indices, const NDIndex *indices, const NDArray<SizeT> *src_ndarray, NDArray<SizeT> *dst_ndarray)
{
// Validate `indices`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indices; i++)
{
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT)
{
expected_dst_ndims--;
num_indexed++;
}
else if (indices[i].type == ND_INDEX_TYPE_SLICE)
{
num_indexed++;
}
else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS)
{
expected_dst_ndims++;
}
else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS)
{
num_ellipsis++;
if (num_ellipsis > 1)
{
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
NO_PARAM, NO_PARAM);
}
}
else
{
__builtin_unreachable();
}
}
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0)
{
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indices, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (int32_t i = 0; i < num_indices; i++)
{
const NDIndex *index = &indices[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT)
{
SizeT input = (SizeT) * ((int32_t *)index->data);
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == -1)
{
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis, src_ndarray->shape[src_axis]);
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
}
else if (index->type == ND_INDEX_TYPE_SLICE)
{
Slice<int32_t> *slice = (Slice<int32_t> *)index->data;
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
dst_axis++;
src_axis++;
}
else if (index->type == ND_INDEX_TYPE_NEWAXIS)
{
dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1;
dst_axis++;
}
else if (index->type == ND_INDEX_TYPE_ELLIPSIS)
{
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++)
{
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++;
src_axis++;
}
}
else
{
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++)
{
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C"
{
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indices, NDIndex *indices, NDArray<int32_t> *src_ndarray,
NDArray<int32_t> *dst_ndarray)
{
index(num_indices, indices, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indices, NDIndex *indices, NDArray<int64_t> *src_ndarray,
NDArray<int64_t> *dst_ndarray)
{
index(num_indices, indices, src_ndarray, dst_ndarray);
}
}

View File

@ -2,7 +2,7 @@ use crate::{
codegen::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
ProxyValue, RangeValue, UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
@ -19,11 +19,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenerator,
},
symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -44,6 +40,14 @@ use std::cmp::min;
use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{
model::*,
object::{
any::AnyObject,
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
},
};
pub fn get_subst_key(
unifier: &mut Unifier,
obj: Option<Type>,
@ -2492,338 +2496,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
)
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
ndims: Type,
v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
unreachable!()
};
let ndims = values
.iter()
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)
})?;
assert!(!ndims.is_empty());
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
.unwrap())
},
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, true),
None,
)
};
let index = ctx
.builder
.build_int_add(
len,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
"",
)
.unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
},
)
.map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple
let expr_to_slice = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) =
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else {
match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts
.iter()
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx
.builder
.build_int_z_extend_or_bit_cast(
ndarray.load_ndims(ctx),
llvm_usize.size_of().get_type(),
"",
)
.unwrap();
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
let ndarray_num_elems = ctx
.builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap();
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
}
/// See [`CodeGenerator::gen_expr`].
pub fn gen_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
@ -3463,18 +3135,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v.data().get(ctx, generator, &index, None).into()
}
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value()
} else {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
return Ok(None);
};
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
let ndarray_ty = value.custom.unwrap();
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = NDArrayObject::from_object(
generator,
ctx,
AnyObject { ty: ndarray_ty, value: ndarray },
);
let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?;
let result = ndarray
.index(generator, ctx, &indices)
.split_unsized(generator, ctx)
.to_basic_value_enum();
return Ok(Some(ValueEnum::Dynamic(result)));
}
TypeEnum::TTuple { .. } => {
let index: u32 =
@ -3517,3 +3197,42 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
_ => unimplemented!(),
}))
}
/// Generate LLVM IR for an [`ExprKind::Slice`]
#[allow(clippy::type_complexity)]
pub fn gen_slice<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<
(
Option<Instance<'ctx, Int<Int32>>>,
Option<Instance<'ctx, Int<Int32>>>,
Option<Instance<'ctx, Int<Int32>>>,
),
String,
> {
let mut help = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr = Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let lower = help(lower)?;
let upper = help(upper)?;
let step = help(step)?;
Ok((lower, upper, step))
}

View File

@ -7,7 +7,7 @@ use super::{
},
llvm_intrinsics,
model::*,
object::ndarray::{nditer::NDIter, NDArray},
object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray},
CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
@ -1119,3 +1119,20 @@ pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next");
CallFunction::begin(generator, ctx, &name).arg(iter).returning_void();
}
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_indices: Instance<'ctx, Int<SizeT>>,
indices: Instance<'ctx, Ptr<Struct<NDIndex>>>,
src_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
dst_ndarray: Instance<'ctx, Ptr<Struct<NDArray>>>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
CallFunction::begin(generator, ctx, &name)
.arg(num_indices)
.arg(indices)
.arg(src_ndarray)
.arg(dst_ndarray)
.returning_void();
}

View File

@ -0,0 +1,293 @@
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub type NDIndexType = Byte;
/// Fields of [`NDIndex`]
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
pub type_: F::Out<Int<NDIndexType>>, // Defined to be uint8_t in IRRT
pub data: F::Out<Ptr<Int<Byte>>>,
}
/// An IRRT representation of an ndarray subscript index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl<'ctx> StructKind<'ctx> for NDIndex {
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
}
}
/// Fields of [`Slice`]
#[derive(Debug, Clone)]
pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>> {
pub start_defined: F::Out<Int<Bool>>,
pub start: F::Out<Int<Int32>>,
pub stop_defined: F::Out<Int<Bool>>,
pub stop: F::Out<Int<Int32>>,
pub step_defined: F::Out<Int<Bool>>,
pub step: F::Out<Int<Int32>>,
}
/// An IRRT representation of an (unresolved) slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Slice;
impl<'ctx> StructKind<'ctx> for Slice {
type Fields<F: FieldTraversal<'ctx>> = SliceFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: traversal.add_auto("start_defined"),
start: traversal.add_auto("start"),
stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add_auto("stop"),
step_defined: traversal.add_auto("step_defined"),
step: traversal.add_auto("step"),
}
}
}
/// A convenience structure to prepare a [`Slice`].
#[derive(Debug, Clone)]
pub struct RustSlice<'ctx> {
pub start: Option<Instance<'ctx, Int<Int32>>>,
pub stop: Option<Instance<'ctx, Int<Int32>>>,
pub step: Option<Instance<'ctx, Int<Int32>>>,
}
impl<'ctx> RustSlice<'ctx> {
/// Write the contents to an LLVM [`Slice`].
pub fn write_to_slice<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Instance<'ctx, Ptr<Struct<Slice>>>,
) {
let false_ = Int(Bool).const_false(generator, ctx.ctx);
let true_ = Int(Bool).const_true(generator, ctx.ctx);
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}
// A convenience enum to prepare an [`NDIndex`].
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Instance<'ctx, Int<Int32>>), // TODO: To be SizeT
Slice(RustSlice<'ctx>),
NewAxis,
Ellipsis,
}
impl<'ctx> RustNDIndex<'ctx> {
/// Get the value to set `NDIndex::type` for this variant.
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
RustNDIndex::NewAxis => 2,
RustNDIndex::Ellipsis => 3,
}
}
/// Write the contents to an LLVM [`NDIndex`].
fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Instance<'ctx, Ptr<Struct<NDIndex>>>,
) {
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr.gep(ctx, |f| f.type_).store(
ctx,
Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id()),
);
// Set `dst_ndindex_ptr->data`
match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = Int(Int32).alloca(generator, ctx);
index_ptr.store(ctx, *in_index);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte)));
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = Struct(Slice).alloca(generator, ctx);
in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte)));
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
}
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
pub fn alloca_ndindices<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindices: &[RustNDIndex<'ctx>],
) -> (Instance<'ctx, Int<SizeT>>, Instance<'ctx, Ptr<Struct<NDIndex>>>) {
let ndindex_model = Struct(NDIndex);
let num_ndindices = Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64);
let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value);
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
let pndindex = ndindices.offset_const(ctx, i as u64);
in_ndindex.write_to_ndindex(generator, ctx, pndindex);
}
(num_ndindices, ndindices)
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Get the ndims [`Type`] after indexing with a given slice.
#[must_use]
pub fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims;
for index in indices {
match index {
RustNDIndex::SingleElement(_) => {
ndims -= 1; // Single elements decrements ndims
}
RustNDIndex::NewAxis => {
ndims += 1; // `np.newaxis` / `none` adds a new axis
}
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
}
}
ndims
}
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
///
/// This function behaves like NumPy's ndarray indexing, but if the indices index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: &[RustNDIndex<'ctx>],
) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims);
let (num_indices, indices) = RustNDIndex::alloca_ndindices(generator, ctx, indices);
call_nac3_ndarray_index(
generator,
ctx,
num_indices,
indices,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
}
pub mod util {
use itertools::Itertools;
use nac3parser::ast::{Expr, ExprKind};
use crate::{
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
use super::{RustNDIndex, RustSlice};
/// Generate LLVM code to transform an ndarray subscript expression to
/// its list of [`RustNDIndex`]
///
/// i.e.,
/// ```python
/// my_ndarray[::3, 1, :2:]
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
/// ```
pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindices: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(RustSlice { start: lower, stop: upper, step })
} else {
// Treat and handle everything else as a single element index.
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
};
rust_ndindices.push(ndindex);
}
Ok(rust_ndindices)
}
}

View File

@ -1,11 +1,12 @@
pub mod factory;
pub mod indexing;
pub mod nditer;
pub mod shape_util;
use inkwell::{
context::Context,
types::BasicType,
values::{BasicValueEnum, PointerValue},
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace,
};
@ -356,6 +357,30 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
#[must_use]
pub fn is_unsized(&self) -> bool {
self.ndims == 0
}
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
/// Otherwise, do nothing and return the ndarray itself.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible.
let zero = Int(SizeT).const_0(generator, ctx.ctx);
let value = self.get_nth_scalar(generator, ctx, zero).value;
ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value })
} else {
ScalarOrNDArray::NDArray(*self)
}
}
/// Fill the ndarray with a scalar.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
@ -373,3 +398,21 @@ impl<'ctx> NDArrayObject<'ctx> {
.unwrap();
}
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(AnyObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
}