[core] codegen/irrt: Add IRRT functions for strided-ndarray
This commit is contained in:
parent
a99ae4828a
commit
624e943cd6
@ -3,3 +3,5 @@
|
|||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
|
#include "irrt/ndarray/basic.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
|
// TODO: To be deleted since NDArray with strides is done.
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
342
nac3core/irrt/irrt/ndarray/basic.hpp
Normal file
342
nac3core/irrt/irrt/ndarray/basic.hpp
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/debug.hpp"
|
||||||
|
#include "irrt/exception.hpp"
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
#include "irrt/ndarray/def.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace basic {
|
||||||
|
/**
|
||||||
|
* @brief Assert that `shape` does not contain negative dimensions.
|
||||||
|
*
|
||||||
|
* @param ndims Number of dimensions in `shape`
|
||||||
|
* @param shape The shape to check on
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
if (shape[axis] < 0) {
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"negative dimensions are not allowed; axis {0} "
|
||||||
|
"has dimension {1}",
|
||||||
|
axis, shape[axis], NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void assert_output_shape_same(SizeT ndarray_ndims,
|
||||||
|
const SizeT* ndarray_shape,
|
||||||
|
SizeT output_ndims,
|
||||||
|
const SizeT* output_shape) {
|
||||||
|
if (ndarray_ndims != output_ndims) {
|
||||||
|
// There is no corresponding NumPy error message like this.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
||||||
|
output_ndims, ndarray_ndims, NO_PARAM);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
||||||
|
if (ndarray_shape[axis] != output_shape[axis]) {
|
||||||
|
// There is no corresponding NumPy error message like this.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||||
|
"Mismatched dimensions on axis {0}, output has "
|
||||||
|
"dimension {1}, but destination ndarray has dimension {2}.",
|
||||||
|
axis, output_shape[axis], ndarray_shape[axis]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the number of elements of an ndarray given its shape.
|
||||||
|
*
|
||||||
|
* @param ndims Number of dimensions in `shape`
|
||||||
|
* @param shape The shape of the ndarray
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||||
|
SizeT size = 1;
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++)
|
||||||
|
size *= shape[axis];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
||||||
|
*
|
||||||
|
* @param ndims Number of elements in `shape` and `indices`
|
||||||
|
* @param shape The shape of the ndarray
|
||||||
|
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
||||||
|
* @param nth The index of the element of interest.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
SizeT axis = ndims - i - 1;
|
||||||
|
SizeT dim = shape[axis];
|
||||||
|
|
||||||
|
indices[axis] = nth % dim;
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the number of elements of an `ndarray`
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.size`
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT size(const NDArray<SizeT>* ndarray) {
|
||||||
|
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return of the number of its content of an `ndarray`.
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.nbytes`.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
||||||
|
return size(ndarray) * ndarray->itemsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
||||||
|
*
|
||||||
|
* This function corresponds to `<an_ndarray>.__len__`.
|
||||||
|
*
|
||||||
|
* @param dst_length The length.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT len(const NDArray<SizeT>* ndarray) {
|
||||||
|
if (ndarray->ndims != 0) {
|
||||||
|
return ndarray->shape[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// numpy prohibits `__len__` on unsized objects
|
||||||
|
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
||||||
|
*
|
||||||
|
* You may want to see ndarray's rules for C-contiguity:
|
||||||
|
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
||||||
|
// References:
|
||||||
|
// - tinynumpy's implementation:
|
||||||
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
||||||
|
// - ndarray's flags["C_CONTIGUOUS"]:
|
||||||
|
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
||||||
|
// - ndarray's rules for C-contiguity:
|
||||||
|
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||||
|
|
||||||
|
// From
|
||||||
|
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
||||||
|
//
|
||||||
|
// The traditional rule is that for an array to be flagged as C contiguous,
|
||||||
|
// the following must hold:
|
||||||
|
//
|
||||||
|
// strides[-1] == itemsize
|
||||||
|
// strides[i] == shape[i+1] * strides[i + 1]
|
||||||
|
// [...]
|
||||||
|
// According to these rules, a 0- or 1-dimensional array is either both
|
||||||
|
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
||||||
|
// can be C- or F- contiguous, or neither, but not both. Though there
|
||||||
|
// there are exceptions for arrays with zero or one item, in the first
|
||||||
|
// case the check is relaxed up to and including the first dimension
|
||||||
|
// with shape[i] == 0. In the second case `strides == itemsize` will
|
||||||
|
// can be true for all dimensions and both flags are set.
|
||||||
|
|
||||||
|
if (ndarray->ndims == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis_i = ndarray->ndims - i - 1;
|
||||||
|
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
||||||
|
*
|
||||||
|
* This function does no bound check.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
||||||
|
void* element = ndarray->data;
|
||||||
|
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
||||||
|
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
||||||
|
*
|
||||||
|
* This function does no bound check.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
||||||
|
void* element = ndarray->data;
|
||||||
|
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis = ndarray->ndims - i - 1;
|
||||||
|
SizeT dim = ndarray->shape[axis];
|
||||||
|
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
||||||
|
nth /= dim;
|
||||||
|
}
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
||||||
|
*
|
||||||
|
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||||
|
SizeT stride_product = 1;
|
||||||
|
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||||
|
SizeT axis = ndarray->ndims - i - 1;
|
||||||
|
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
||||||
|
stride_product *= ndarray->shape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set an element in `ndarray`.
|
||||||
|
*
|
||||||
|
* @param pelement Pointer to the element in `ndarray` to be set.
|
||||||
|
* @param pvalue Pointer to the value `pelement` will be set to.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
||||||
|
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
||||||
|
*
|
||||||
|
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
||||||
|
// TODO: Handle overlapping.
|
||||||
|
|
||||||
|
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
||||||
|
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||||
|
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||||
|
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace basic
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::basic;
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
||||||
|
assert_shape_no_negative(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
||||||
|
assert_shape_no_negative(ndims, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
||||||
|
const int32_t* ndarray_shape,
|
||||||
|
int32_t output_ndims,
|
||||||
|
const int32_t* output_shape) {
|
||||||
|
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
||||||
|
const int64_t* ndarray_shape,
|
||||||
|
int64_t output_ndims,
|
||||||
|
const int64_t* output_shape) {
|
||||||
|
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||||
|
return size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||||
|
return size(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||||
|
return nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||||
|
return nbytes(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
||||||
|
return len(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
||||||
|
return len(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
||||||
|
return is_c_contiguous(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
||||||
|
return is_c_contiguous(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
||||||
|
return get_nth_pelement(ndarray, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
||||||
|
return get_nth_pelement(ndarray, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
||||||
|
return get_pelement_by_indices(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||||
|
return get_pelement_by_indices(ndarray, indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||||
|
set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||||
|
set_strides_by_shape(ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
}
|
51
nac3core/irrt/irrt/ndarray/def.hpp
Normal file
51
nac3core/irrt/irrt/ndarray/def.hpp
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/**
|
||||||
|
* @brief The NDArray object
|
||||||
|
*
|
||||||
|
* Official numpy implementation:
|
||||||
|
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
||||||
|
*
|
||||||
|
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
||||||
|
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
||||||
|
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
||||||
|
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
||||||
|
* `data`. There are also minor differences in the struct layout.
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
struct NDArray {
|
||||||
|
/**
|
||||||
|
* @brief The number of bytes of a single element in `data`.
|
||||||
|
*/
|
||||||
|
SizeT itemsize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The number of dimensions of this shape.
|
||||||
|
*/
|
||||||
|
SizeT ndims;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The NDArray shape, with length equal to `ndims`.
|
||||||
|
*
|
||||||
|
* Note that it may contain 0.
|
||||||
|
*/
|
||||||
|
SizeT* shape;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Array strides, with length equal to `ndims`
|
||||||
|
*
|
||||||
|
* The stride values are in units of bytes, not number of elements.
|
||||||
|
*
|
||||||
|
* Note that `strides` can have negative values or contain 0.
|
||||||
|
*/
|
||||||
|
SizeT* strides;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The underlying data this `ndarray` is pointing to.
|
||||||
|
*/
|
||||||
|
void* data;
|
||||||
|
};
|
||||||
|
} // namespace
|
@ -59,6 +59,27 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
|||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
||||||
|
///
|
||||||
|
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
||||||
|
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'_, '_>,
|
||||||
|
name: &str,
|
||||||
|
) -> String {
|
||||||
|
let mut name = name.to_owned();
|
||||||
|
match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
|
32 => {}
|
||||||
|
64 => name.push_str("64"),
|
||||||
|
bit_width => {
|
||||||
|
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name
|
||||||
|
}
|
||||||
|
|
||||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||||
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
|
||||||
/// NO numeric slice in python.
|
/// NO numeric slice in python.
|
||||||
|
250
nac3core/src/codegen/irrt/ndarray/basic.rs
Normal file
250
nac3core/src/codegen/irrt/ndarray/basic.rs
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
expr::{create_and_call_function, infer_and_call_function},
|
||||||
|
irrt::get_usize_dependent_function_name,
|
||||||
|
types::ProxyType,
|
||||||
|
values::{NDArrayValue, ProxyValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndims: IntValue<'ctx>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_util_assert_shape_no_negative",
|
||||||
|
);
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray_ndims: IntValue<'ctx>,
|
||||||
|
ndarray_shape: PointerValue<'ctx>,
|
||||||
|
output_ndims: IntValue<'ctx>,
|
||||||
|
output_shape: IntValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_util_assert_output_shape_same",
|
||||||
|
);
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[
|
||||||
|
(llvm_usize.into(), ndarray_ndims.into()),
|
||||||
|
(llvm_pusize.into(), ndarray_shape.into()),
|
||||||
|
(llvm_usize.into(), output_ndims.into()),
|
||||||
|
(llvm_pusize.into(), output_shape.into()),
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("size"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("nbytes"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_usize.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("len"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_i1.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
Some("is_c_contiguous"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_pi8.into()),
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
|
||||||
|
Some("pelement"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: PointerValue<'ctx>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name =
|
||||||
|
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
Some(llvm_pi8.into()),
|
||||||
|
&[
|
||||||
|
(llvm_ndarray.into(), ndarray.as_base_value().into()),
|
||||||
|
(llvm_pusize.into(), indices.into()),
|
||||||
|
],
|
||||||
|
Some("pelement"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let llvm_ndarray = ndarray.get_type().as_base_type();
|
||||||
|
|
||||||
|
let name =
|
||||||
|
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||||
|
|
||||||
|
create_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: NDArrayValue<'ctx>,
|
||||||
|
dst_ndarray: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
|
||||||
|
|
||||||
|
infer_and_call_function(
|
||||||
|
ctx,
|
||||||
|
&name,
|
||||||
|
None,
|
||||||
|
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
@ -15,6 +15,9 @@ use crate::codegen::{
|
|||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
pub use basic::*;
|
||||||
|
|
||||||
|
mod basic;
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
||||||
@ -77,7 +80,7 @@ where
|
|||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &G,
|
generator: &G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
index: IntValue<'ctx>,
|
index: IntValue<'ctx>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||||
@ -201,8 +204,8 @@ where
|
|||||||
/// `NDArray`.
|
/// `NDArray`.
|
||||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||||
generator: &mut G,
|
generator: &G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
indices: &Index,
|
indices: &Index,
|
||||||
) -> IntValue<'ctx>
|
) -> IntValue<'ctx>
|
Loading…
Reference in New Issue
Block a user