From 624e943cd6138a8c56836e202531abec1107e0cc Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 25 Nov 2024 15:46:56 +0800 Subject: [PATCH] [core] codegen/irrt: Add IRRT functions for strided-ndarray --- nac3core/irrt/irrt.cpp | 2 + nac3core/irrt/irrt/ndarray.hpp | 2 + nac3core/irrt/irrt/ndarray/basic.hpp | 342 ++++++++++++++++++ nac3core/irrt/irrt/ndarray/def.hpp | 51 +++ nac3core/src/codegen/irrt/mod.rs | 21 ++ nac3core/src/codegen/irrt/ndarray/basic.rs | 250 +++++++++++++ .../irrt/{ndarray.rs => ndarray/mod.rs} | 9 +- 7 files changed, 674 insertions(+), 3 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/basic.hpp create mode 100644 nac3core/irrt/irrt/ndarray/def.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/basic.rs rename nac3core/src/codegen/irrt/{ndarray.rs => ndarray/mod.rs} (99%) diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 7966322a..088b84fb 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -3,3 +3,5 @@ #include "irrt/math.hpp" #include "irrt/ndarray.hpp" #include "irrt/slice.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index b239152a..9abfc56a 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -2,6 +2,8 @@ #include "irrt/int_types.hpp" +// TODO: To be deleted since NDArray with strides is done. + namespace { template SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp new file mode 100644 index 00000000..05ee30fc --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -0,0 +1,342 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray { +namespace basic { +/** + * @brief Assert that `shape` does not contain negative dimensions. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape to check on + */ +template +void assert_shape_no_negative(SizeT ndims, const SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + if (shape[axis] < 0) { + raise_exception(SizeT, EXN_VALUE_ERROR, + "negative dimensions are not allowed; axis {0} " + "has dimension {1}", + axis, shape[axis], NO_PARAM); + } + } +} + +/** + * @brief Assert that two shapes are the same in the context of writing output to an ndarray. + */ +template +void assert_output_shape_same(SizeT ndarray_ndims, + const SizeT* ndarray_shape, + SizeT output_ndims, + const SizeT* output_shape) { + if (ndarray_ndims != output_ndims) { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}", + output_ndims, ndarray_ndims, NO_PARAM); + } + + for (SizeT axis = 0; axis < ndarray_ndims; axis++) { + if (ndarray_shape[axis] != output_shape[axis]) { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, + "Mismatched dimensions on axis {0}, output has " + "dimension {1}, but destination ndarray has dimension {2}.", + axis, output_shape[axis], ndarray_shape[axis]); + } + } +} + +/** + * @brief Return the number of elements of an ndarray given its shape. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape of the ndarray + */ +template +SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) + size *= shape[axis]; + return size; +} + +/** + * @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape. + * + * @param ndims Number of elements in `shape` and `indices` + * @param shape The shape of the ndarray + * @param indices The returned indices indexing the ndarray with shape `shape`. + * @param nth The index of the element of interest. + */ +template +void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) { + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = ndims - i - 1; + SizeT dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} + +/** + * @brief Return the number of elements of an `ndarray` + * + * This function corresponds to `.size` + */ +template +SizeT size(const NDArray* ndarray) { + return calc_size_from_shape(ndarray->ndims, ndarray->shape); +} + +/** + * @brief Return of the number of its content of an `ndarray`. + * + * This function corresponds to `.nbytes`. + */ +template +SizeT nbytes(const NDArray* ndarray) { + return size(ndarray) * ndarray->itemsize; +} + +/** + * @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object. + * + * This function corresponds to `.__len__`. + * + * @param dst_length The length. + */ +template +SizeT len(const NDArray* ndarray) { + if (ndarray->ndims != 0) { + return ndarray->shape[0]; + } + + // numpy prohibits `__len__` on unsized objects + raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM); + __builtin_unreachable(); +} + +/** + * @brief Return a boolean indicating if `ndarray` is (C-)contiguous. + * + * You may want to see ndarray's rules for C-contiguity: + * https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + */ +template +bool is_c_contiguous(const NDArray* ndarray) { + // References: + // - tinynumpy's implementation: + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 + // - ndarray's flags["C_CONTIGUOUS"]: + // https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags + // - ndarray's rules for C-contiguity: + // https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + + // From + // https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: + // + // The traditional rule is that for an array to be flagged as C contiguous, + // the following must hold: + // + // strides[-1] == itemsize + // strides[i] == shape[i+1] * strides[i + 1] + // [...] + // According to these rules, a 0- or 1-dimensional array is either both + // C- and F-contiguous, or neither; and an array with 2+ dimensions + // can be C- or F- contiguous, or neither, but not both. Though there + // there are exceptions for arrays with zero or one item, in the first + // case the check is relaxed up to and including the first dimension + // with shape[i] == 0. In the second case `strides == itemsize` will + // can be true for all dimensions and both flags are set. + + if (ndarray->ndims == 0) { + return true; + } + + if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) { + return false; + } + + for (SizeT i = 1; i < ndarray->ndims; i++) { + SizeT axis_i = ndarray->ndims - i - 1; + if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) { + return false; + } + } + + return true; +} + +/** + * @brief Return the pointer to the element indexed by `indices` along the ndarray's axes. + * + * This function does no bound check. + */ +template +void* get_pelement_by_indices(const NDArray* ndarray, const SizeT* indices) { + void* element = ndarray->data; + for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) + element = static_cast(element) + indices[dim_i] * ndarray->strides[dim_i]; + return element; +} + +/** + * @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view. + * + * This function does no bound check. + */ +template +void* get_nth_pelement(const NDArray* ndarray, SizeT nth) { + void* element = ndarray->data; + for (SizeT i = 0; i < ndarray->ndims; i++) { + SizeT axis = ndarray->ndims - i - 1; + SizeT dim = ndarray->shape[axis]; + element = static_cast(element) + ndarray->strides[axis] * (nth % dim); + nth /= dim; + } + return element; +} + +/** + * @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous. + * + * You might want to read https://ajcr.net/stride-guide-part-1/. + */ +template +void set_strides_by_shape(NDArray* ndarray) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndarray->ndims; i++) { + SizeT axis = ndarray->ndims - i - 1; + ndarray->strides[axis] = stride_product * ndarray->itemsize; + stride_product *= ndarray->shape[axis]; + } +} + +/** + * @brief Set an element in `ndarray`. + * + * @param pelement Pointer to the element in `ndarray` to be set. + * @param pvalue Pointer to the value `pelement` will be set to. + */ +template +void set_pelement_value(NDArray* ndarray, void* pelement, const void* pvalue) { + __builtin_memcpy(pelement, pvalue, ndarray->itemsize); +} + +/** + * @brief Copy data from one ndarray to another of the exact same size and itemsize. + * + * Both ndarrays will be viewed in their flatten views when copying the elements. + */ +template +void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { + // TODO: Make this faster with memcpy when we see a contiguous segment. + // TODO: Handle overlapping. + + debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize); + + for (SizeT i = 0; i < size(src_ndarray); i++) { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); + } +} +} // namespace basic +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::basic; + +void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) { + assert_shape_no_negative(ndims, shape); +} + +void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) { + assert_shape_no_negative(ndims, shape); +} + +void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims, + const int32_t* ndarray_shape, + int32_t output_ndims, + const int32_t* output_shape) { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); +} + +void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims, + const int64_t* ndarray_shape, + int64_t output_ndims, + const int64_t* output_shape) { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); +} + +uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return size(ndarray); +} + +uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return size(ndarray); +} + +uint32_t __nac3_ndarray_nbytes(NDArray* ndarray) { + return nbytes(ndarray); +} + +uint64_t __nac3_ndarray_nbytes64(NDArray* ndarray) { + return nbytes(ndarray); +} + +int32_t __nac3_ndarray_len(NDArray* ndarray) { + return len(ndarray); +} + +int64_t __nac3_ndarray_len64(NDArray* ndarray) { + return len(ndarray); +} + +bool __nac3_ndarray_is_c_contiguous(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +bool __nac3_ndarray_is_c_contiguous64(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +void* __nac3_ndarray_get_nth_pelement(const NDArray* ndarray, int32_t nth) { + return get_nth_pelement(ndarray, nth); +} + +void* __nac3_ndarray_get_nth_pelement64(const NDArray* ndarray, int64_t nth) { + return get_nth_pelement(ndarray, nth); +} + +void* __nac3_ndarray_get_pelement_by_indices(const NDArray* ndarray, int32_t* indices) { + return get_pelement_by_indices(ndarray, indices); +} + +void* __nac3_ndarray_get_pelement_by_indices64(const NDArray* ndarray, int64_t* indices) { + return get_pelement_by_indices(ndarray, indices); +} + +void __nac3_ndarray_set_strides_by_shape(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_set_strides_by_shape64(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_copy_data(NDArray* src_ndarray, NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_copy_data64(NDArray* src_ndarray, NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/def.hpp b/nac3core/irrt/irrt/ndarray/def.hpp new file mode 100644 index 00000000..7cf5c755 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/def.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "irrt/int_types.hpp" + +namespace { +/** + * @brief The NDArray object + * + * Official numpy implementation: + * https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface + * + * Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The + * difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is + * that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject` + * does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its + * `data`. There are also minor differences in the struct layout. + */ +template +struct NDArray { + /** + * @brief The number of bytes of a single element in `data`. + */ + SizeT itemsize; + + /** + * @brief The number of dimensions of this shape. + */ + SizeT ndims; + + /** + * @brief The NDArray shape, with length equal to `ndims`. + * + * Note that it may contain 0. + */ + SizeT* shape; + + /** + * @brief Array strides, with length equal to `ndims` + * + * The stride values are in units of bytes, not number of elements. + * + * Note that `strides` can have negative values or contain 0. + */ + SizeT* strides; + + /** + * @brief The underlying data this `ndarray` is pointing to. + */ + void* data; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 970dbf27..ed922db5 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -59,6 +59,27 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) irrt_mod } +/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`. +/// +/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. +/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. +#[must_use] +pub fn get_usize_dependent_function_name( + generator: &G, + ctx: &CodeGenContext<'_, '_>, + name: &str, +) -> String { + let mut name = name.to_owned(); + match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => {} + 64 => name.push_str("64"), + bit_width => { + panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") + } + } + name +} + /// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// NO numeric slice in python. diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs new file mode 100644 index 00000000..ae194b45 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -0,0 +1,250 @@ +use inkwell::{ + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, +}; + +use crate::codegen::{ + expr::{create_and_call_function, infer_and_call_function}, + irrt::get_usize_dependent_function_name, + types::ProxyType, + values::{NDArrayValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndims: IntValue<'ctx>, + shape: PointerValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_shape_no_negative", + ); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())], + None, + None, + ); +} + +pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray_ndims: IntValue<'ctx>, + ndarray_shape: PointerValue<'ctx>, + output_ndims: IntValue<'ctx>, + output_shape: IntValue<'ctx>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_output_shape_same", + ); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[ + (llvm_usize.into(), ndarray_ndims.into()), + (llvm_pusize.into(), ndarray_shape.into()), + (llvm_usize.into(), output_ndims.into()), + (llvm_pusize.into(), output_shape.into()), + ], + None, + None, + ); +} + +pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + Some("size"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() +} + +pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + Some("nbytes"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() +} + +pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + + create_and_call_function( + ctx, + &name, + Some(llvm_usize.into()), + &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + Some("len"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() +} + +pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) -> IntValue<'ctx> { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + + create_and_call_function( + ctx, + &name, + Some(llvm_i1.into()), + &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + Some("is_c_contiguous"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() +} + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + index: IntValue<'ctx>, +) -> PointerValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + + create_and_call_function( + ctx, + &name, + Some(llvm_pi8.into()), + &[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())], + Some("pelement"), + None, + ) + .map(BasicValueEnum::into_pointer_value) + .unwrap() +} + +pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: PointerValue<'ctx>, +) -> PointerValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + + create_and_call_function( + ctx, + &name, + Some(llvm_pi8.into()), + &[ + (llvm_ndarray.into(), ndarray.as_base_value().into()), + (llvm_pusize.into(), indices.into()), + ], + Some("pelement"), + None, + ) + .map(BasicValueEnum::into_pointer_value) + .unwrap() +} + +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, +) { + let llvm_ndarray = ndarray.get_type().as_base_type(); + + let name = + get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + + create_and_call_function( + ctx, + &name, + None, + &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + None, + None, + ); +} + +pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayValue<'ctx>, + dst_ndarray: NDArrayValue<'ctx>, +) { + let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + + infer_and_call_function( + ctx, + &name, + None, + &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs similarity index 99% rename from nac3core/src/codegen/irrt/ndarray.rs rename to nac3core/src/codegen/irrt/ndarray/mod.rs index 541c1175..b9d02d12 100644 --- a/nac3core/src/codegen/irrt/ndarray.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -15,6 +15,9 @@ use crate::codegen::{ }, CodeGenContext, CodeGenerator, }; +pub use basic::*; + +mod basic; /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the /// calculated total size. @@ -77,7 +80,7 @@ where /// `NDArray`. pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, ndarray: NDArrayValue<'ctx>, ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { @@ -201,8 +204,8 @@ where /// `NDArray`. /// * `indices` - The multidimensional index to compute the flattened index for. pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, indices: &Index, ) -> IntValue<'ctx>