forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement general ndarray basic indexing

This commit is contained in:
lyken 2024-07-28 17:45:02 +08:00
parent 4b14609342
commit bd5cb14d0d
10 changed files with 846 additions and 344 deletions

View File

@ -1,13 +1,11 @@
#pragma once #pragma once
#include <irrt/int_defs.hpp> #include <irrt/int_defs.hpp>
#include <irrt/slice.hpp>
#include <irrt/utils.hpp> #include <irrt/utils.hpp>
// NDArray indices are always `uint32_t`. // NDArray indices are always `uint32_t`.
using NDIndex = uint32_t; using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace { namespace {
// adapted from GNU Scientific Library: // adapted from GNU Scientific Library:
@ -43,7 +41,7 @@ SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len,
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims,
SizeT num_dims, NDIndex* idxs) { SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1; SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) { for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1; SizeT i = num_dims - dim - 1;
@ -55,7 +53,7 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims,
template <typename SizeT> template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims,
const NDIndex* indices, const NDIndexInt* indices,
SizeT num_indices) { SizeT num_indices) {
SizeT idx = 0; SizeT idx = 0;
SizeT stride = 1; SizeT stride = 1;
@ -104,8 +102,8 @@ void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims,
template <typename SizeT> template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims, SizeT src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) { for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1; SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
@ -293,24 +291,24 @@ uint64_t __nac3_ndarray_calc_size64(const uint64_t* list_data,
} }
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims,
uint32_t num_dims, NDIndex* idxs) { uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims,
uint64_t num_dims, NDIndex* idxs) { uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
} }
uint32_t __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, uint32_t __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims,
const NDIndex* indices, const NDIndexInt* indices,
uint32_t num_indices) { uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
num_indices); num_indices);
} }
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims,
const NDIndex* indices, const NDIndexInt* indices,
uint64_t num_indices) { uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
num_indices); num_indices);
@ -333,16 +331,16 @@ void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims, uint32_t src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
out_idx); out_idx);
} }
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims, uint64_t src_ndims,
const NDIndex* in_idx, const NDIndexInt* in_idx,
NDIndex* out_idx) { NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
out_idx); out_idx);
} }

View File

@ -0,0 +1,200 @@
#pragma once
#include <irrt/error_context.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* See https://numpy.org/doc/stable/user/basics.indexing.html#single-element-indexing
*
* `data` points to a `SliceIndex`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
*
* `data` points to a `UserRange`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief An index used in ndarray indexing
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see comments of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see comments of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
namespace util {
/**
* @brief Return the expected rank of the resulting ndarray
* created by indexing an ndarray of rank `ndims` using `indexes`.
*/
template <typename SizeT>
void deduce_ndims_after_indexing(ErrorContext* errctx, SizeT* final_ndims,
SizeT ndims, SizeT num_indexes,
const NDIndex* indexes) {
if (num_indexes > ndims) {
errctx->set_exception(errctx->exceptions->index_error,
"too many indices for array: array is "
"{0}-dimensional, but {1} were indexed",
ndims, num_indexes);
return;
}
*final_ndims = ndims;
for (SizeT i = 0; i < num_indexes; i++) {
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
// An index demotes the rank by 1
(*final_ndims)--;
}
}
}
} // namespace util
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This is function very similar to performing `dst_ndarray = src_ndarray[indexes]` in Python (where the variables
* can all be found in the parameter of this function).
*
* In other words, this function takes in an ndarray (`src_ndarray`), index it with `indexes`, and return the
* indexed array (by writing the result to `dst_ndarray`).
*
* This function also does proper assertions on `indexes`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indexes`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indexes Indexes to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template <typename SizeT>
void index(ErrorContext* errctx, SizeT num_indexes, const NDIndex* indexes,
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indexes; i++) {
const NDIndex* index = &indexes[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SliceIndex input = *((SliceIndex*)index->data);
SliceIndex k = slice::resolve_index_in_length(
src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS) {
errctx->set_exception(errctx->exceptions->index_error,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis,
src_ndarray->shape[src_axis]);
return;
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
UserSlice* input = (UserSlice*)index->data;
Slice slice;
input->indices_checked(errctx, src_ndarray->shape[src_axis],
&slice);
if (errctx->has_exception()) {
return;
}
dst_ndarray->data +=
(SizeT)slice.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] =
((SizeT)slice.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)slice.len();
dst_axis++;
src_axis++;
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_indexing_deduce_ndims_after_indexing(
ErrorContext* errctx, int32_t* result, int32_t ndims, int32_t num_indexes,
const NDIndex* indexes) {
ndarray::indexing::util::deduce_ndims_after_indexing(errctx, result, ndims,
num_indexes, indexes);
}
void __nac3_ndarray_indexing_deduce_ndims_after_indexing64(
ErrorContext* errctx, int64_t* result, int64_t ndims, int64_t num_indexes,
const NDIndex* indexes) {
ndarray::indexing::util::deduce_ndims_after_indexing(errctx, result, ndims,
num_indexes, indexes);
}
void __nac3_ndarray_index(ErrorContext* errctx, int32_t num_indexes,
NDIndex* indexes, NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(errctx, num_indexes, indexes, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(ErrorContext* errctx, int64_t num_indexes,
NDIndex* indexes, NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(errctx, num_indexes, indexes, src_ndarray, dst_ndarray);
}
}

View File

@ -1,7 +1,13 @@
#pragma once #pragma once
#include <irrt/error_context.hpp>
#include <irrt/int_defs.hpp> #include <irrt/int_defs.hpp>
#include <irrt/slice.hpp> #include <irrt/slice.hpp>
#include <irrt/utils.hpp>
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace { namespace {

View File

@ -6,5 +6,6 @@
#include <irrt/int_defs.hpp> #include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp> #include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp> #include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/slice.hpp> #include <irrt/slice.hpp>
#include <irrt/utils.hpp> #include <irrt/utils.hpp>

View File

@ -6,11 +6,13 @@
#include <cstdlib> #include <cstdlib>
#include <test/test_core.hpp> #include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp> #include <test/test_ndarray_basic.hpp>
#include <test/test_ndarray_indexing.hpp>
#include <test/test_slice.hpp> #include <test/test_slice.hpp>
int main() { int main() {
test::core::run(); test::core::run();
test::slice::run(); test::slice::run();
test::ndarray_basic::run(); test::ndarray_basic::run();
test::ndarray_indexing::run();
return 0; return 0;
} }

View File

@ -0,0 +1,220 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_indexing {
void test_normal_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[-2::, 1::2]`
UserSlice subscript_1;
subscript_1.set_start(-2);
UserSlice subscript_2;
subscript_2.set_start(1);
subscript_2.set_step(2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_exception(&errctx);
int32_t expected_shape[dst_ndims] = {2, 2};
int32_t expected_strides[dst_ndims] = {32, 16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
// dst_ndarray[0, 0]
assert_values_match(5.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 0})));
// dst_ndarray[0, 1]
assert_values_match(7.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0, 1})));
// dst_ndarray[1, 0]
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 0})));
// dst_ndarray[1, 1]
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1, 1})));
}
void test_normal_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int32_t src_itemsize = sizeof(double);
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 1;
int32_t dst_shape[dst_ndims] = {999}; // Empty values
int32_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int32_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[2, ::-2]`
int32_t subscript_1 = 2;
UserSlice subscript_2;
subscript_2.set_step(-2);
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_no_exception(&errctx);
int32_t expected_shape[dst_ndims] = {2};
int32_t expected_strides[dst_ndims] = {-16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){0})));
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int32_t[dst_ndims]){1})));
}
void test_index_subscript_out_of_bounds() {
/*
# Consider `my_array`
print(my_array.shape)
# (4, 5, 6)
my_array[2, 100] # error, index subscript at axis 1 is out of bounds
*/
BEGIN_TEST();
// Prepare src_ndarray
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {3, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {
.data = (uint8_t *)nullptr, // placeholder, we wouldn't access it
.itemsize = sizeof(double), // placeholder
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Create the subscripts in `my_array[2, 100]`
int32_t subscript_1 = 2;
int32_t subscript_2 = 100;
const int32_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT,
.data = (uint8_t *)&subscript_2}};
// Prepare dst_ndarray
const int32_t dst_ndims = 0;
int32_t dst_shape[dst_ndims] = {};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {.data = nullptr, // placehloder
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
ErrorContext errctx = create_testing_errctx();
ndarray::indexing::index(&errctx, num_indexes, indexes, &src_ndarray,
&dst_ndarray);
assert_errctx_has_exception(&errctx, errctx.exceptions->index_error);
}
void run() {
test_normal_1();
test_normal_2();
test_index_subscript_out_of_bounds();
}
} // namespace ndarray_indexing
} // namespace test

View File

@ -1,10 +1,14 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{
irrt::slice::{RustUserSlice, SliceIndex},
structure::ndarray::NpArray,
};
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, gen_in_range_check, get_llvm_abi_type, get_llvm_type,
@ -21,11 +25,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenTask, CodeGenerator,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo}, magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -34,19 +34,25 @@ use crate::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Located, Location, Operator,
Unaryop, StrRef, Unaryop,
};
use ndarray::{
allocation::alloca_ndarray,
indexing::{call_nac3_ndarray_index, RustNDIndex},
}; };
use super::structure::cslice::CSlice;
use super::{ use super::{
model::*, model::*,
structure::exception::{Exception, ExceptionId}, structure::{
cslice::CSlice,
exception::{Exception, ExceptionId},
},
}; };
pub fn get_subst_key( pub fn get_subst_key(
@ -2127,334 +2133,149 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
/// Generates code for a subscript expression on an `ndarray`. /// Generates code for a subscript expression on an `ndarray`.
/// ///
/// * `ty` - The `Type` of the `NDArray` elements. /// * `elem_ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. /// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value. /// * `src_ndarray` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`. /// * `subscript` - The subscript expression used to index into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, elem_ty: Type,
ndims: Type, ndims: Type,
v: NDArrayValue<'ctx>, src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
slice: &Expr<Option<Type>>, subscript: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type(); // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
let llvm_i32 = ctx.ctx.i32_type(); let tyctx = generator.type_context(ctx.ctx);
let llvm_usize = generator.get_size_type(ctx.ctx); let sizet_model = IntModel(SizeT);
let slice_index_model = IntModel(SliceIndex::default());
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { // Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindexes: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower: start, upper: stop, step } = &index_expr.node
{
// Helper function here to deduce code duplication
type ValueExpr = Option<Box<Located<ExprKind<Option<Type>>, Option<Type>>>>;
let mut help = |value_expr: &ValueExpr| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?;
let value_expr =
slice_index_model.check_value(tyctx, ctx.ctx, value_expr).unwrap();
Some(value_expr)
}
})
};
let start = help(start)?;
let stop = help(stop)?;
let step = help(step)?;
RustNDIndex::Slice(RustUserSlice { start, stop, step })
} else {
// Anything else that is not a slice (might be illegal values),
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32,
)?;
let index = slice_index_model.check_value(tyctx, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
};
rust_ndindexes.push(ndindex);
}
// Extract the `ndims` from a `Type` to `i128`
// We *HAVE* to know this statically, this is used to determine
// whether this subscript expression returns a scalar or an ndarray
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
else {
unreachable!() unreachable!()
}; };
assert_eq!(ndims_values.len(), 1);
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
let ndims = values // Check for "too many indices for array: array is ..." error
.iter() if src_ndims < rust_ndindexes.len() as i128 {
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) ctx.make_assert(
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)
})?;
assert!(!ndims.is_empty());
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator, generator,
ctx, ctx.ctx.bool_type().const_int(1, false),
|_, ctx| { "0:IndexError",
Ok(ctx "too many indices for array: array is {0}-dimensional, but {1} were indexed",
.builder [None, None, None],
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") ctx.current_loc,
.unwrap()) );
}, }
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe { let dst_ndims = RustNDIndex::deduce_ndims_after_slicing(&rust_ndindexes, src_ndims as i32);
v.dim_sizes().get_typed_unchecked( let dst_ndarray = alloca_ndarray(
ctx, generator,
generator, ctx,
&llvm_usize.const_int(dim, true), sizet_model.constant(tyctx, ctx.ctx, dst_ndims as u64),
None, "subndarray",
) )?;
};
let index = ctx // Prepare the subscripts
.builder let (num_ndindexes, ndindexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, &rust_ndindexes);
.build_int_add(
len,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
"",
)
.unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) // NOTE: IRRT does check for indexing errors
}, call_nac3_ndarray_index(generator, ctx, num_ndindexes, ndindexes, src_ndarray, dst_ndarray);
)
.map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple // ...and return the result, with two cases
let expr_to_slice = |generator: &mut G, let result_llvm_value: BasicValueEnum<'_> = if dst_ndims == 0 {
ctx: &mut CodeGenContext<'ctx, '_>, // 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return the element
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) =
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true)))) let pelement = dst_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "pelement"); // `*data` points to the first element by definition
}
ExprKind::Slice { lower, upper, step } => { // Cast the opaque `pelement` ptr to `elem_ty`
let dim_sz = unsafe { let elem_ty = ctx.get_llvm_type(generator, elem_ty);
v.dim_sizes().get_typed_unchecked( let pelement = ctx
ctx, .builder
generator, .build_pointer_cast(
&llvm_usize.const_int(dim, false), pelement.value,
None, elem_ty.ptr_type(AddressSpace::default()),
) "pelement_casted",
}; )
.unwrap();
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) ctx.builder.build_load(pelement, "element").unwrap().as_basic_value_enum()
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None);
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else { } else {
match &slice.node { // 2) ndims > 0 (other cases), return subndarray
ExprKind::Tuple { elts, .. } => { dst_ndarray.value.as_basic_value_enum()
let slices = elts };
.iter() Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ctx
.builder
.build_int_z_extend_or_bit_cast(
ndarray.load_ndims(ctx),
llvm_usize.size_of().get_type(),
"",
)
.unwrap();
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
let ndarray_num_elems = ctx
.builder
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
.unwrap();
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
} }
/// See [`CodeGenerator::gen_expr`]. /// See [`CodeGenerator::gen_expr`].
@ -3097,17 +2918,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let tyctx = generator.type_context(ctx.ctx);
let pndarray_model = PtrModel(StructModel(NpArray));
let v = if let Some(v) = generator.gen_expr(ctx, value)? { let (dtype, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value() let Some(ndarray) = generator.gen_expr(ctx, value)? else {
} else {
return Ok(None); return Ok(None);
}; };
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); let ndarray =
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
return gen_ndarray_subscript_expr(
generator, ctx, *dtype, *ndims, ndarray, slice,
);
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let index: u32 = let index: u32 =

View File

@ -0,0 +1,170 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, setup_error_context},
slice::{RustUserSlice, SliceIndex, UserSlice},
util::{function::CallFunction, get_sizet_dependent_function_name},
},
model::*,
structure::ndarray::NpArray,
CodeGenContext, CodeGenerator,
};
pub type NDIndexType = Byte;
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<F: FieldVisitor> {
pub type_: F::Field<IntModel<NDIndexType>>, // Defined to be uint8_t in IRRT
pub data: F::Field<PtrModel<IntModel<Byte>>>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl StructKind for NDIndex {
type Fields<F: FieldVisitor> = NDIndexFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields { type_: visitor.add("type"), data: visitor.add("data") }
}
}
// An enum variant to store the content
// and type of an NDIndex in high level.
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Int<'ctx, SliceIndex>),
Slice(RustUserSlice<'ctx>),
}
impl<'ctx> RustNDIndex<'ctx> {
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
}
}
fn write_to_ndindex(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) {
let ndindex_type_model = IntModel(NDIndexType::default());
let slice_index_model = IntModel(SliceIndex::default());
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr
.gep(ctx, |f| f.type_)
.store(ctx, ndindex_type_model.constant(tyctx, ctx.ctx, self.get_type_id()));
// Set `dst_ndindex_ptr->data`
let data = match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = slice_index_model.alloca(tyctx, ctx, "index");
index_ptr.store(ctx, *in_index);
index_ptr.transmute(tyctx, ctx, IntModel(Byte), "")
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = user_slice_model.alloca(tyctx, ctx, "user_slice");
in_rust_slice.write_to_user_slice(tyctx, ctx, user_slice_ptr);
user_slice_ptr.transmute(tyctx, ctx, IntModel(Byte), "")
}
};
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
}
/// Allocate an array of `NDIndex`es onto the stack and return its stack pointer.
pub fn alloca_ndindexes(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindexes: &[RustNDIndex<'ctx>],
) -> (Int<'ctx, SizeT>, Ptr<'ctx, StructModel<NDIndex>>) {
let sizet_model = IntModel(SizeT);
let ndindex_model = StructModel(NDIndex);
let num_ndindexes = sizet_model.constant(tyctx, ctx.ctx, in_ndindexes.len() as u64);
let ndindexes = ndindex_model.array_alloca(tyctx, ctx, num_ndindexes.value, "ndindexes");
for (i, in_ndindex) in in_ndindexes.iter().enumerate() {
let i = sizet_model.constant(tyctx, ctx.ctx, i as u64);
let pndindex = ndindexes.offset(tyctx, ctx, i.value, "");
in_ndindex.write_to_ndindex(tyctx, ctx, pndindex);
}
(num_ndindexes, ndindexes)
}
#[must_use]
pub fn deduce_ndims_after_slicing(slices: &[RustNDIndex], original_ndims: i32) -> i32 {
let mut final_ndims: i32 = original_ndims;
for slice in slices {
match slice {
RustNDIndex::SingleElement(_) => {
final_ndims -= 1;
}
RustNDIndex::Slice(_) => {}
}
}
final_ndims
}
}
pub fn call_nac3_ndarray_indexing_deduce_ndims_after_indexing<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
num_ndindexes: Int<'ctx, SizeT>,
ndindexs: Ptr<'ctx, StructModel<NDIndex>>,
) -> Int<'ctx, SizeT> {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let pfinal_ndims = sizet_model.alloca(tyctx, ctx, "pfinal_ndims");
let errctx_ptr = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(
tyctx,
"__nac3_ndarray_indexing_deduce_ndims_after_indexing",
),
)
.arg("errctx", errctx_ptr)
.arg("result", pfinal_ndims)
.arg("ndims", ndims)
.arg("num_ndindexs", num_ndindexes)
.arg("ndindexs", ndindexs)
.returning_void();
check_error_context(generator, ctx, errctx_ptr);
pfinal_ndims.load(tyctx, ctx, "final_ndims")
}
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_indexes: Int<'ctx, SizeT>,
indexes: Ptr<'ctx, StructModel<NDIndex>>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
let tyctx = generator.type_context(ctx.ctx);
let perrctx = setup_error_context(tyctx, ctx);
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_index"),
)
.arg("errctx", perrctx)
.arg("num_indexes", num_indexes)
.arg("indexes", indexes)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning_void();
check_error_context(generator, ctx, perrctx);
}

View File

@ -1,2 +1,3 @@
pub mod allocation; pub mod allocation;
pub mod basic; pub mod basic;
pub mod indexing;

View File

@ -1,3 +1,81 @@
use crate::codegen::model::*; use crate::codegen::{model::*, CodeGenContext};
// nac3core's slicing index/length values are always int32_t
pub type SliceIndex = Int32; pub type SliceIndex = Int32;
#[derive(Debug, Clone)]
pub struct UserSliceFields<F: FieldVisitor> {
pub start_defined: F::Field<IntModel<Bool>>,
pub start: F::Field<IntModel<SliceIndex>>,
pub stop_defined: F::Field<IntModel<Bool>>,
pub stop: F::Field<IntModel<SliceIndex>>,
pub step_defined: F::Field<IntModel<Bool>>,
pub step: F::Field<IntModel<SliceIndex>>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UserSlice;
impl StructKind for UserSlice {
type Fields<F: FieldVisitor> = UserSliceFields<F>;
fn visit_fields<F: FieldVisitor>(&self, visitor: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: visitor.add("start_defined"),
start: visitor.add("start"),
stop_defined: visitor.add("stop_defined"),
stop: visitor.add("stop"),
step_defined: visitor.add("step_defined"),
step: visitor.add("step"),
}
}
}
#[derive(Debug, Clone)]
pub struct RustUserSlice<'ctx> {
pub start: Option<Int<'ctx, SliceIndex>>,
pub stop: Option<Int<'ctx, SliceIndex>>,
pub step: Option<Int<'ctx, SliceIndex>>,
}
impl<'ctx> RustUserSlice<'ctx> {
// Set the values of an LLVM UserSlice
// in the format of Python's `slice()`
pub fn write_to_user_slice(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Ptr<'ctx, StructModel<UserSlice>>,
) {
let bool_model = IntModel(Bool);
let false_ = bool_model.constant(tyctx, ctx.ctx, 0);
let true_ = bool_model.constant(tyctx, ctx.ctx, 1);
// TODO: Code duplication. Probably okay...?
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}