1
0
forked from M-Labs/nac3

core/ndstrides: add basic ndarray IRRT functions

This commit is contained in:
lyken 2024-07-28 16:08:37 +08:00
parent 3886dffe68
commit 79f66e8517
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
6 changed files with 627 additions and 1 deletions

View File

@ -0,0 +1,288 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace basic {
namespace util {
/**
* @brief Asserts that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template <typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Returns the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template <typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template <typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices,
SizeT nth) {
for (int32_t i = 0; i < ndims; i++) {
int32_t axis = ndims - i - 1;
int32_t dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
} // namespace util
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template <typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return util::calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template <typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The returned result
*/
template <typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
// numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) {
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object",
NO_PARAM, NO_PARAM, NO_PARAM);
}
return ndarray->shape[0];
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template <typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// Other references:
// - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] !=
ndarray->shape[axis_i + 1] + ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices`.
*/
template <typename SizeT>
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray,
const SizeT* indices) {
uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`.
*
* This function does no bound check.
*/
template <typename SizeT>
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
SizeT* indices = (SizeT*)__builtin_alloca(sizeof(SizeT) * ndarray->ndims);
util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth);
return get_pelement_by_indices(ndarray, indices);
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape`
* and assuming that the ndarray is fully c-contagious.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template <typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
int axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template <typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement,
const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template <typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy
__builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element,
src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims,
int32_t* shape) {
util::assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims,
int64_t* shape) {
util::assert_shape_no_negative(ndims, shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) { return len(ndarray); }
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) { return len(ndarray); }
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
uint8_t* __nac3_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
uint8_t* __nac3_get_nth_pelement64(const NDArray<int64_t>* ndarray,
int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -3,5 +3,6 @@
#include <irrt/core.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/util.hpp>

View File

@ -11,8 +11,10 @@
#define IRRT_TESTING
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
int main() {
test::core::run();
test::ndarray_basic::run();
return 0;
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_basic {
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int32_t shape[4] = {2, 3, 5, 7};
assert_values_match(
210, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int32_t shape[4] = {2, 0, 5, 7};
assert_values_match(
0, ndarray::basic::util::calc_size_from_shape<int32_t>(4, shape));
}
void run() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
}
} // namespace ndarray_basic
} // namespace test

View File

@ -5,6 +5,7 @@ mod test;
pub mod util;
use super::model::*;
use super::structure::ndarray::NpArray;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
@ -987,3 +988,117 @@ pub fn setup_irrt_exceptions<'ctx>(
global.set_initializer(&exn_id);
}
}
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: Int<'ctx, SizeT>,
shape: Ptr<'ctx, IntModel<SizeT>>,
) {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_util_assert_shape_no_negative"),
)
.arg("ndims", ndims)
.arg("shape", shape)
.returning_void();
}
pub fn call_nac3_ndarray_size<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SizeT> {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_size"),
)
.arg("ndarray", pndarray)
.returning_auto("size")
}
pub fn call_nac3_ndarray_nbytes<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SizeT> {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_nbytes"),
)
.arg("ndarray", pndarray)
.returning_auto("nbytes")
}
pub fn call_nac3_ndarray_len<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, SizeT> {
CallFunction::begin(tyctx, ctx, &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_len"))
.arg("ndarray", pndarray)
.returning_auto("len")
}
pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ptr: Ptr<'ctx, StructModel<NpArray>>,
) -> Int<'ctx, Bool> {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_is_c_contiguous"),
)
.arg("ndarray", ndarray_ptr)
.returning_auto("is_c_contiguous")
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
pndarray: Ptr<'ctx, StructModel<NpArray>>,
index: Int<'ctx, SizeT>,
) -> Ptr<'ctx, IntModel<Byte>> {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"),
)
.arg("ndarray", pndarray)
.arg("index", index)
.returning_auto("pelement")
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
pdnarray: Ptr<'ctx, StructModel<NpArray>>,
) {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_set_strides_by_shape"),
)
.arg("ndarray", pdnarray)
.returning_void();
}
pub fn call_nac3_ndarray_copy_data<'ctx>(
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
dst_ndarray: Ptr<'ctx, StructModel<NpArray>>,
) {
CallFunction::begin(
tyctx,
ctx,
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_copy_data"),
)
.arg("src_ndarray", src_ndarray)
.arg("dst_ndarray", dst_ndarray)
.returning_void();
}

View File

@ -1,4 +1,10 @@
use crate::codegen::*;
use irrt::{
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
};
use crate::{codegen::*, symbol_resolver::SymbolValue};
pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> {
pub data: F::Out<PtrModel<IntModel<Byte>>>,
@ -26,9 +32,193 @@ impl<'ctx> StructKind<'ctx> for NpArray {
}
}
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: Type,
pub value: Ptr<'ctx, StructModel<NpArray>>,
}
impl<'ctx> NDArrayObject<'ctx> {
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated on the stack.
///
/// The returned ndarray's content will be:
/// - `data`: set to `nullptr`.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
pub fn alloca_uninitialized<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: Type,
name: &str,
) -> Self {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let ndarray_model = StructModel(NpArray);
let ndarray_data_model = PtrModel(IntModel(Byte));
let pndarray = ndarray_model.alloca(tyctx, ctx, name);
let data = ndarray_data_model.nullptr(tyctx, ctx.ctx);
let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap();
let itemsize = sizet_model.s_extend_or_bit_cast(tyctx, ctx, itemsize, "itemsize");
let ndims_val = extract_ndims(&ctx.unifier, ndims);
let ndims_val = sizet_model.constant(tyctx, ctx.ctx, ndims_val);
let shape = sizet_model.array_alloca(tyctx, ctx, ndims_val.value, "shape");
let strides = sizet_model.array_alloca(tyctx, ctx, ndims_val.value, "strides");
pndarray.gep(ctx, |f| f.data).store(ctx, data);
pndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize);
pndarray.gep(ctx, |f| f.ndims).store(ctx, ndims_val);
pndarray.gep(ctx, |f| f.shape).store(ctx, shape);
pndarray.gep(ctx, |f| f.strides).store(ctx, strides);
NDArrayObject { dtype, ndims, value: pndarray }
}
/// Get this ndarray's `ndims` as an LLVM constant.
pub fn get_ndims(
&self,
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
let sizet_model = IntModel(SizeT);
let ndims_val = extract_ndims(&ctx.unifier, self.ndims);
sizet_model.constant(tyctx, ctx.ctx, ndims_val)
}
/// Return true if this ndarray is unsized.
#[must_use]
pub fn is_unsized(&self, unifier: &Unifier) -> bool {
extract_ndims(unifier, self.ndims) == 0
}
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
/// The allocated data buffer is considered to be *owned* by the ndarray.
///
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
///
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
pub fn create_data<G: CodeGenerator + ?Sized>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
let byte_model = IntModel(Byte);
let data = byte_model.array_alloca(tyctx, ctx, self.get_ndims(tyctx, ctx).value, "data");
self.value.gep(ctx, |f| f.data).store(ctx, data);
self.update_strides_by_shape(tyctx, ctx);
}
/// Get the `np.size()` of this ndarray.
pub fn size(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(tyctx, ctx, self.value)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(tyctx, ctx, self.value)
}
/// Get the `len()` of this ndarray.
pub fn len(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_len(tyctx, ctx, self.value)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(tyctx, ctx, self.value)
}
/// Get the pointer to the n-th (0-based) element.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
name: &str,
) -> PointerValue<'ctx> {
let tyctx = generator.type_context(ctx.ctx);
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(tyctx, ctx, self.value, nth);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Please refer to the IRRT implementation to see its purpose.
pub fn update_strides_by_shape(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(tyctx, ctx, self.value);
}
/// Copy data from another ndarray.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from(
&self,
tyctx: TypeContext<'ctx>,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(tyctx, ctx, src.value, self.value);
}
}