From 46110d2f6a22cc94c95e4fd188161a60c5db885e Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 28 Jul 2024 16:08:37 +0800 Subject: [PATCH] core/ndstrides: add basic ndarray IRRT functions --- nac3core/irrt/irrt/ndarray/basic.hpp | 288 ++++++++++++++++++++++ nac3core/irrt/irrt_everything.hpp | 1 + nac3core/irrt/irrt_test.cpp | 2 + nac3core/irrt/test/test_ndarray_basic.hpp | 30 +++ nac3core/src/codegen/irrt/mod.rs | 115 +++++++++ nac3core/src/codegen/structure/ndarray.rs | 192 ++++++++++++++- 6 files changed, 627 insertions(+), 1 deletion(-) create mode 100644 nac3core/irrt/irrt/ndarray/basic.hpp create mode 100644 nac3core/irrt/test/test_ndarray_basic.hpp diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp new file mode 100644 index 00000000..502aed79 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -0,0 +1,288 @@ +#pragma once + +#include +#include +#include + +namespace { +namespace ndarray { +namespace basic { +namespace util { +/** + * @brief Asserts that `shape` does not contain negative dimensions. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape to check on + */ +template +void assert_shape_no_negative(SizeT ndims, const SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + if (shape[axis] < 0) { + raise_exception(SizeT, EXN_VALUE_ERROR, + "negative dimensions are not allowed; axis {0} " + "has dimension {1}", + axis, shape[axis], NO_PARAM); + } + } +} + +/** + * @brief Returns the number of elements of an ndarray given its shape. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape of the ndarray + */ +template +SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis]; + return size; +} + +/** + * @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape. + * + * @param ndims Number of elements in `shape` and `indices` + * @param shape The shape of the ndarray + * @param indices The returned indices indexing the ndarray with shape `shape`. + * @param nth The index of the element of interest. + */ +template +void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, + SizeT nth) { + for (int32_t i = 0; i < ndims; i++) { + int32_t axis = ndims - i - 1; + int32_t dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} +} // namespace util + +/** + * @brief Return the number of elements of an `ndarray` + * + * This function corresponds to `.size` + */ +template +SizeT size(const NDArray* ndarray) { + return util::calc_size_from_shape(ndarray->ndims, ndarray->shape); +} + +/** + * @brief Return of the number of its content of an `ndarray`. + * + * This function corresponds to `.nbytes`. + */ +template +SizeT nbytes(const NDArray* ndarray) { + return size(ndarray) * ndarray->itemsize; +} + +/** + * @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object. + * + * This function corresponds to `.__len__`. + * + * @param dst_length The returned result + */ +template +SizeT len(const NDArray* ndarray) { + // numpy prohibits `__len__` on unsized objects + if (ndarray->ndims == 0) { + raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", + NO_PARAM, NO_PARAM, NO_PARAM); + } + return ndarray->shape[0]; +} + +/** + * @brief Return a boolean indicating if `ndarray` is (C-)contiguous. + * + * You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + */ +template +bool is_c_contiguous(const NDArray* ndarray) { + // Other references: + // - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 + // - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags + // - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + + // From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: + // + // The traditional rule is that for an array to be flagged as C contiguous, + // the following must hold: + // + // strides[-1] == itemsize + // strides[i] == shape[i+1] * strides[i + 1] + // [...] + // According to these rules, a 0- or 1-dimensional array is either both + // C- and F-contiguous, or neither; and an array with 2+ dimensions + // can be C- or F- contiguous, or neither, but not both. Though there + // there are exceptions for arrays with zero or one item, in the first + // case the check is relaxed up to and including the first dimension + // with shape[i] == 0. In the second case `strides == itemsize` will + // can be true for all dimensions and both flags are set. + + if (ndarray->ndims == 0) { + return true; + } + + if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) { + return false; + } + + for (SizeT i = 1; i < ndarray->ndims; i++) { + SizeT axis_i = ndarray->ndims - i - 1; + if (ndarray->strides[axis_i] != + ndarray->shape[axis_i + 1] + ndarray->strides[axis_i + 1]) { + return false; + } + } + + return true; +} + +/** + * @brief Return the pointer to the element indexed by `indices`. + */ +template +uint8_t* get_pelement_by_indices(const NDArray* ndarray, + const SizeT* indices) { + uint8_t* element = ndarray->data; + for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) + element += indices[dim_i] * ndarray->strides[dim_i]; + return element; +} + +/** + * @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`. + * + * This function does no bound check. + */ +template +uint8_t* get_nth_pelement(const NDArray* ndarray, SizeT nth) { + SizeT* indices = (SizeT*)__builtin_alloca(sizeof(SizeT) * ndarray->ndims); + util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, nth); + return get_pelement_by_indices(ndarray, indices); +} + +/** + * @brief Update the strides of an ndarray given an ndarray `shape` + * and assuming that the ndarray is fully c-contagious. + * + * You might want to read https://ajcr.net/stride-guide-part-1/. + */ +template +void set_strides_by_shape(NDArray* ndarray) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndarray->ndims; i++) { + int axis = ndarray->ndims - i - 1; + ndarray->strides[axis] = stride_product * ndarray->itemsize; + stride_product *= ndarray->shape[axis]; + } +} + +/** + * @brief Set an element in `ndarray`. + * + * @param pelement Pointer to the element in `ndarray` to be set. + * @param pvalue Pointer to the value `pelement` will be set to. + */ +template +void set_pelement_value(NDArray* ndarray, uint8_t* pelement, + const uint8_t* pvalue) { + __builtin_memcpy(pelement, pvalue, ndarray->itemsize); +} + +/** + * @brief Copy data from one ndarray to another of the exact same size and itemsize. + * + * Both ndarrays will be viewed in their flatten views when copying the elements. + */ +template +void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { + // TODO: Make this faster with memcpy + + __builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize); + + for (SizeT i = 0; i < size(src_ndarray); i++) { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, + src_element); + } +} +} // namespace basic +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::basic; + +void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, + int32_t* shape) { + util::assert_shape_no_negative(ndims, shape); +} + +void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, + int64_t* shape) { + util::assert_shape_no_negative(ndims, shape); +} + +uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return size(ndarray); +} + +uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return size(ndarray); +} + +uint32_t __nac3_ndarray_nbytes(NDArray* ndarray) { + return nbytes(ndarray); +} + +uint64_t __nac3_ndarray_nbytes64(NDArray* ndarray) { + return nbytes(ndarray); +} + +int32_t __nac3_ndarray_len(NDArray* ndarray) { return len(ndarray); } + +int64_t __nac3_ndarray_len64(NDArray* ndarray) { return len(ndarray); } + +bool __nac3_ndarray_is_c_contiguous(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +bool __nac3_ndarray_is_c_contiguous64(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +uint8_t* __nac3_get_nth_pelement(const NDArray* ndarray, int32_t nth) { + return get_nth_pelement(ndarray, nth); +} + +uint8_t* __nac3_get_nth_pelement64(const NDArray* ndarray, + int64_t nth) { + return get_nth_pelement(ndarray, nth); +} + +void __nac3_ndarray_set_strides_by_shape(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_set_strides_by_shape64(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_copy_data(NDArray* src_ndarray, + NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_copy_data64(NDArray* src_ndarray, + NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index d1db2364..13e0168d 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -3,5 +3,6 @@ #include #include #include +#include #include #include diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index 7ffa2e66..9183eba0 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -11,8 +11,10 @@ #define IRRT_TESTING #include +#include int main() { test::core::run(); + test::ndarray_basic::run(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/test/test_ndarray_basic.hpp b/nac3core/irrt/test/test_ndarray_basic.hpp new file mode 100644 index 00000000..1bbdab26 --- /dev/null +++ b/nac3core/irrt/test/test_ndarray_basic.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace test { +namespace ndarray_basic { +void test_calc_size_from_shape_normal() { + // Test shapes with normal values + BEGIN_TEST(); + + int32_t shape[4] = {2, 3, 5, 7}; + assert_values_match( + 210, ndarray::basic::util::calc_size_from_shape(4, shape)); +} + +void test_calc_size_from_shape_has_zero() { + // Test shapes with 0 in them + BEGIN_TEST(); + + int32_t shape[4] = {2, 0, 5, 7}; + assert_values_match( + 0, ndarray::basic::util::calc_size_from_shape(4, shape)); +} + +void run() { + test_calc_size_from_shape_normal(); + test_calc_size_from_shape_has_zero(); +} +} // namespace ndarray_basic +} // namespace test \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 9150afe8..8ed02f87 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -5,6 +5,7 @@ mod test; pub mod util; use super::model::*; +use super::structure::ndarray::NpArray; use super::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, @@ -987,3 +988,117 @@ pub fn setup_irrt_exceptions<'ctx>( global.set_initializer(&exn_id); } } + +pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ndims: Int<'ctx, SizeT>, + shape: Ptr<'ctx, IntModel>, +) { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_util_assert_shape_no_negative"), + ) + .arg("ndims", ndims) + .arg("shape", shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_size<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, +) -> Int<'ctx, SizeT> { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_size"), + ) + .arg("ndarray", pndarray) + .returning_auto("size") +} + +pub fn call_nac3_ndarray_nbytes<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, +) -> Int<'ctx, SizeT> { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_nbytes"), + ) + .arg("ndarray", pndarray) + .returning_auto("nbytes") +} + +pub fn call_nac3_ndarray_len<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, +) -> Int<'ctx, SizeT> { + CallFunction::begin(tyctx, ctx, &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_len")) + .arg("ndarray", pndarray) + .returning_auto("len") +} + +pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray_ptr: Ptr<'ctx, StructModel>, +) -> Int<'ctx, Bool> { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_is_c_contiguous"), + ) + .arg("ndarray", ndarray_ptr) + .returning_auto("is_c_contiguous") +} + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + pndarray: Ptr<'ctx, StructModel>, + index: Int<'ctx, SizeT>, +) -> Ptr<'ctx, IntModel> { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_get_nth_pelement"), + ) + .arg("ndarray", pndarray) + .arg("index", index) + .returning_auto("pelement") +} + +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + pdnarray: Ptr<'ctx, StructModel>, +) { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_set_strides_by_shape"), + ) + .arg("ndarray", pdnarray) + .returning_void(); +} + +pub fn call_nac3_ndarray_copy_data<'ctx>( + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Ptr<'ctx, StructModel>, + dst_ndarray: Ptr<'ctx, StructModel>, +) { + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_copy_data"), + ) + .arg("src_ndarray", src_ndarray) + .arg("dst_ndarray", dst_ndarray) + .returning_void(); +} diff --git a/nac3core/src/codegen/structure/ndarray.rs b/nac3core/src/codegen/structure/ndarray.rs index a7313c22..a0c44561 100644 --- a/nac3core/src/codegen/structure/ndarray.rs +++ b/nac3core/src/codegen/structure/ndarray.rs @@ -1,4 +1,10 @@ -use crate::codegen::*; +use irrt::{ + call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement, + call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes, + call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, +}; + +use crate::{codegen::*, symbol_resolver::SymbolValue}; pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> { pub data: F::Out>>, @@ -26,9 +32,193 @@ impl<'ctx> StructKind<'ctx> for NpArray { } } +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +} + #[derive(Debug, Clone, Copy)] pub struct NDArrayObject<'ctx> { pub dtype: Type, pub ndims: Type, pub value: Ptr<'ctx, StructModel>, } + +impl<'ctx> NDArrayObject<'ctx> { + /// Allocate an ndarray on the stack given its `ndims` and `dtype`. + /// + /// `shape` and `strides` will be automatically allocated on the stack. + /// + /// The returned ndarray's content will be: + /// - `data`: set to `nullptr`. + /// - `itemsize`: set to the `sizeof()` of `dtype`. + /// - `ndims`: set to the value of `ndims`. + /// - `shape`: allocated with an array of length `ndims` with uninitialized values. + /// - `strides`: allocated with an array of length `ndims` with uninitialized values. + pub fn alloca_uninitialized( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: Type, + name: &str, + ) -> Self { + let tyctx = generator.type_context(ctx.ctx); + let sizet_model = IntModel(SizeT); + let ndarray_model = StructModel(NpArray); + let ndarray_data_model = PtrModel(IntModel(Byte)); + + let pndarray = ndarray_model.alloca(tyctx, ctx, name); + + let data = ndarray_data_model.nullptr(tyctx, ctx.ctx); + + let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap(); + let itemsize = sizet_model.s_extend_or_bit_cast(tyctx, ctx, itemsize, "itemsize"); + + let ndims_val = extract_ndims(&ctx.unifier, ndims); + let ndims_val = sizet_model.constant(tyctx, ctx.ctx, ndims_val); + + let shape = sizet_model.array_alloca(tyctx, ctx, ndims_val.value, "shape"); + let strides = sizet_model.array_alloca(tyctx, ctx, ndims_val.value, "strides"); + + pndarray.gep(ctx, |f| f.data).store(ctx, data); + pndarray.gep(ctx, |f| f.itemsize).store(ctx, itemsize); + pndarray.gep(ctx, |f| f.ndims).store(ctx, ndims_val); + pndarray.gep(ctx, |f| f.shape).store(ctx, shape); + pndarray.gep(ctx, |f| f.strides).store(ctx, strides); + + NDArrayObject { dtype, ndims, value: pndarray } + } + + /// Get this ndarray's `ndims` as an LLVM constant. + pub fn get_ndims( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, SizeT> { + let sizet_model = IntModel(SizeT); + + let ndims_val = extract_ndims(&ctx.unifier, self.ndims); + sizet_model.constant(tyctx, ctx.ctx, ndims_val) + } + + /// Return true if this ndarray is unsized. + #[must_use] + pub fn is_unsized(&self, unifier: &Unifier) -> bool { + extract_ndims(unifier, self.ndims) == 0 + } + + /// Initialize an ndarray's `data` by allocating a buffer on the stack. + /// The allocated data buffer is considered to be *owned* by the ndarray. + /// + /// `strides` of the ndarray will also be updated with `set_strides_by_shape`. + /// + /// `shape` and `itemsize` of the ndarray ***must*** be initialized first. + pub fn create_data( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + let byte_model = IntModel(Byte); + + let data = byte_model.array_alloca(tyctx, ctx, self.get_ndims(tyctx, ctx).value, "data"); + self.value.gep(ctx, |f| f.data).store(ctx, data); + + self.update_strides_by_shape(tyctx, ctx); + } + + /// Get the `np.size()` of this ndarray. + pub fn size( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, SizeT> { + call_nac3_ndarray_size(tyctx, ctx, self.value) + } + + /// Get the `ndarray.nbytes` of this ndarray. + pub fn nbytes( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, SizeT> { + call_nac3_ndarray_nbytes(tyctx, ctx, self.value) + } + + /// Get the `len()` of this ndarray. + pub fn len( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, SizeT> { + call_nac3_ndarray_len(tyctx, ctx, self.value) + } + + /// Check if this ndarray is C-contiguous. + /// + /// See NumPy's `flags["C_CONTIGUOUS"]`: + pub fn is_c_contiguous( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, Bool> { + call_nac3_ndarray_is_c_contiguous(tyctx, ctx, self.value) + } + + /// Get the pointer to the n-th (0-based) element. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_nth_pelement( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Int<'ctx, SizeT>, + name: &str, + ) -> PointerValue<'ctx> { + let tyctx = generator.type_context(ctx.ctx); + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_nth_pelement(tyctx, ctx, self.value, nth); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name) + .unwrap() + } + + /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. + /// + /// Please refer to the IRRT implementation to see its purpose. + pub fn update_strides_by_shape( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_ndarray_set_strides_by_shape(tyctx, ctx, self.value); + } + + /// Copy data from another ndarray. + /// + /// Panics if the `dtype`s of ndarrays are different. + pub fn copy_data_from( + &self, + tyctx: TypeContext<'ctx>, + ctx: &mut CodeGenContext<'ctx, '_>, + src: NDArrayObject<'ctx>, + ) { + assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); + call_nac3_ndarray_copy_data(tyctx, ctx, src.value, self.value); + } +}