From 4a26a0e222d5de01017604605d81a8510b66e309 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 11:38:05 +0800 Subject: [PATCH] core/ndstrides: introduce NDArray NDArray with strides. --- nac3core/irrt/irrt.cpp | 4 +- nac3core/irrt/irrt/ndarray/basic.hpp | 380 +++++++++++++++++++++ nac3core/irrt/irrt/ndarray/def.hpp | 47 +++ nac3core/src/codegen/irrt/mod.rs | 118 ++++++- nac3core/src/codegen/mod.rs | 13 +- nac3core/src/codegen/object/mod.rs | 1 + nac3core/src/codegen/object/ndarray/mod.rs | 349 +++++++++++++++++++ 7 files changed, 902 insertions(+), 10 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/basic.hpp create mode 100644 nac3core/irrt/irrt/ndarray/def.hpp create mode 100644 nac3core/src/codegen/object/ndarray/mod.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 31ba09f1..1a4dcf93 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,4 +1,6 @@ #include #include #include -#include +#include +#include +#include \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp new file mode 100644 index 00000000..20ed1e83 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -0,0 +1,380 @@ +#pragma once + +#include +#include +#include +#include + +namespace +{ +namespace ndarray +{ +namespace basic +{ +/** + * @brief Asserts that `shape` does not contain negative dimensions. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape to check on + */ +template void assert_shape_no_negative(SizeT ndims, const SizeT *shape) +{ + for (SizeT axis = 0; axis < ndims; axis++) + { + if (shape[axis] < 0) + { + raise_exception(SizeT, EXN_VALUE_ERROR, + "negative dimensions are not allowed; axis {0} " + "has dimension {1}", + axis, shape[axis], NO_PARAM); + } + } +} + +/** + * @brief Check two shapes are the same in the context of writing outputting to an ndarray. + * + * This function throws error messages for output shape mismatches. + */ +template +void assert_output_shape_same(SizeT ndarray_ndims, const SizeT *ndarray_shape, SizeT output_ndims, + const SizeT *output_shape) +{ + if (ndarray_ndims != output_ndims) + { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}", + output_ndims, ndarray_ndims, NO_PARAM); + } + + for (SizeT axis = 0; axis < ndarray_ndims; axis++) + { + if (ndarray_shape[axis] != output_shape[axis]) + { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, + "Mismatched dimensions on axis {0}, output has " + "dimension {1}, but destination ndarray has dimension {2}.", + axis, output_shape[axis], ndarray_shape[axis]); + } + } +} + +/** + * @brief Returns the number of elements of an ndarray given its shape. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape of the ndarray + */ +template SizeT calc_size_from_shape(SizeT ndims, const SizeT *shape) +{ + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) + size *= shape[axis]; + return size; +} + +/** + * @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape. + * + * @param ndims Number of elements in `shape` and `indices` + * @param shape The shape of the ndarray + * @param indices The returned indices indexing the ndarray with shape `shape`. + * @param nth The index of the element of interest. + */ +template void set_indices_by_nth(SizeT ndims, const SizeT *shape, SizeT *indices, SizeT nth) +{ + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = ndims - i - 1; + SizeT dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} + +/** + * @brief Return the number of elements of an `ndarray` + * + * This function corresponds to `.size` + */ +template SizeT size(const NDArray *ndarray) +{ + return calc_size_from_shape(ndarray->ndims, ndarray->shape); +} + +/** + * @brief Return of the number of its content of an `ndarray`. + * + * This function corresponds to `.nbytes`. + */ +template SizeT nbytes(const NDArray *ndarray) +{ + return size(ndarray) * ndarray->itemsize; +} + +/** + * @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object. + * + * This function corresponds to `.__len__`. + * + * @param dst_length The returned result + */ +template SizeT len(const NDArray *ndarray) +{ + // numpy prohibits `__len__` on unsized objects + if (ndarray->ndims == 0) + { + raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM); + } + else + { + return ndarray->shape[0]; + } +} + +/** + * @brief Return a boolean indicating if `ndarray` is (C-)contiguous. + * + * You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + */ +template bool is_c_contiguous(const NDArray *ndarray) +{ + // Other references: + // - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 + // - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags + // - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + + // From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: + // + // The traditional rule is that for an array to be flagged as C contiguous, + // the following must hold: + // + // strides[-1] == itemsize + // strides[i] == shape[i+1] * strides[i + 1] + // [...] + // According to these rules, a 0- or 1-dimensional array is either both + // C- and F-contiguous, or neither; and an array with 2+ dimensions + // can be C- or F- contiguous, or neither, but not both. Though there + // there are exceptions for arrays with zero or one item, in the first + // case the check is relaxed up to and including the first dimension + // with shape[i] == 0. In the second case `strides == itemsize` will + // can be true for all dimensions and both flags are set. + + if (ndarray->ndims == 0) + { + return true; + } + + if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) + { + return false; + } + + for (SizeT i = 1; i < ndarray->ndims; i++) + { + SizeT axis_i = ndarray->ndims - i - 1; + if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) + { + return false; + } + } + + return true; +} + +/** + * @brief Return the pointer to the element indexed by `indices`. + */ +template uint8_t *get_pelement_by_indices(const NDArray *ndarray, const SizeT *indices) +{ + uint8_t *element = ndarray->data; + for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) + element += indices[dim_i] * ndarray->strides[dim_i]; + return element; +} + +/** + * @brief Convenience function. Like `get_pelement_by_indices` but + * reinterprets the element pointer. + */ +template T *get_ptr(const NDArray *ndarray, const SizeT *indices) +{ + return (T *)get_pelement_by_indices(ndarray, indices); +} + +/** + * @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`. + * + * This function does no bound check. + */ +template uint8_t *get_nth_pelement(const NDArray *ndarray, SizeT nth) +{ + uint8_t *element = ndarray->data; + for (SizeT i = 0; i < ndarray->ndims; i++) + { + SizeT axis = ndarray->ndims - i - 1; + SizeT dim = ndarray->shape[axis]; + element += ndarray->strides[axis] * (nth % dim); + nth /= dim; + } + return element; +} + +/** + * @brief Update the strides of an ndarray given an ndarray `shape` + * and assuming that the ndarray is fully c-contagious. + * + * You might want to read https://ajcr.net/stride-guide-part-1/. + */ +template void set_strides_by_shape(NDArray *ndarray) +{ + SizeT stride_product = 1; + for (SizeT i = 0; i < ndarray->ndims; i++) + { + SizeT axis = ndarray->ndims - i - 1; + ndarray->strides[axis] = stride_product * ndarray->itemsize; + stride_product *= ndarray->shape[axis]; + } +} + +/** + * @brief Set an element in `ndarray`. + * + * @param pelement Pointer to the element in `ndarray` to be set. + * @param pvalue Pointer to the value `pelement` will be set to. + */ +template void set_pelement_value(NDArray *ndarray, uint8_t *pelement, const uint8_t *pvalue) +{ + __builtin_memcpy(pelement, pvalue, ndarray->itemsize); +} + +/** + * @brief Copy data from one ndarray to another of the exact same size and itemsize. + * + * Both ndarrays will be viewed in their flatten views when copying the elements. + */ +template void copy_data(const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + // TODO: Make this faster with memcpy + + debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize); + + for (SizeT i = 0; i < size(src_ndarray); i++) + { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); + } +} +} // namespace basic +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::basic; + + void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t *shape) + { + assert_shape_no_negative(ndims, shape); + } + + void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t *shape) + { + assert_shape_no_negative(ndims, shape); + } + + void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims, const int32_t *ndarray_shape, + int32_t output_ndims, const int32_t *output_shape) + { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); + } + + void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims, const int64_t *ndarray_shape, + int64_t output_ndims, const int64_t *output_shape) + { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); + } + + uint32_t __nac3_ndarray_size(NDArray *ndarray) + { + return size(ndarray); + } + + uint64_t __nac3_ndarray_size64(NDArray *ndarray) + { + return size(ndarray); + } + + uint32_t __nac3_ndarray_nbytes(NDArray *ndarray) + { + return nbytes(ndarray); + } + + uint64_t __nac3_ndarray_nbytes64(NDArray *ndarray) + { + return nbytes(ndarray); + } + + int32_t __nac3_ndarray_len(NDArray *ndarray) + { + return len(ndarray); + } + + int64_t __nac3_ndarray_len64(NDArray *ndarray) + { + return len(ndarray); + } + + bool __nac3_ndarray_is_c_contiguous(NDArray *ndarray) + { + return is_c_contiguous(ndarray); + } + + bool __nac3_ndarray_is_c_contiguous64(NDArray *ndarray) + { + return is_c_contiguous(ndarray); + } + + uint8_t *__nac3_ndarray_get_nth_pelement(const NDArray *ndarray, int32_t nth) + { + return get_nth_pelement(ndarray, nth); + } + + uint8_t *__nac3_ndarray_get_nth_pelement64(const NDArray *ndarray, int64_t nth) + { + return get_nth_pelement(ndarray, nth); + } + + uint8_t *__nac3_ndarray_get_pelement_by_indices(const NDArray *ndarray, int32_t *indices) + { + return get_pelement_by_indices(ndarray, indices); + } + + uint8_t *__nac3_ndarray_get_pelement_by_indices64(const NDArray *ndarray, int64_t *indices) + { + return get_pelement_by_indices(ndarray, indices); + } + + void __nac3_ndarray_set_strides_by_shape(NDArray *ndarray) + { + set_strides_by_shape(ndarray); + } + + void __nac3_ndarray_set_strides_by_shape64(NDArray *ndarray) + { + set_strides_by_shape(ndarray); + } + + void __nac3_ndarray_copy_data(NDArray *src_ndarray, NDArray *dst_ndarray) + { + copy_data(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_copy_data64(NDArray *src_ndarray, NDArray *dst_ndarray) + { + copy_data(src_ndarray, dst_ndarray); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/def.hpp b/nac3core/irrt/irrt/ndarray/def.hpp new file mode 100644 index 00000000..270109b1 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/def.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include + +namespace +{ +/** + * @brief The NDArray object + * + * The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst + */ +template struct NDArray +{ + /** + * @brief The underlying data this `ndarray` is pointing to. + * + * Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized. + */ + uint8_t *data; + + /** + * @brief The number of bytes of a single element in `data`. + */ + SizeT itemsize; + + /** + * @brief The number of dimensions of this shape. + */ + SizeT ndims; + + /** + * @brief The NDArray shape, with length equal to `ndims`. + * + * Note that it may contain 0. + */ + SizeT *shape; + + /** + * @brief Array strides, with length equal to `ndims` + * + * The stride values are in units of bytes, not number of elements. + * + * Note that `strides` can have negative values. + */ + SizeT *strides; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index b60021c8..2ad905d2 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -5,10 +5,14 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, - llvm_intrinsics, CodeGenContext, CodeGenerator, + llvm_intrinsics, + model::*, + object::ndarray::NDArray, + CodeGenContext, CodeGenerator, }; use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::stmt::gen_for_callback_incrementing; +use function::CallFunction; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -974,3 +978,115 @@ pub fn setup_irrt_exceptions<'ctx>( global.set_initializer(&exn_id); } } + +pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndims: Instance<'ctx, Int>, + shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_shape_no_negative", + ); + CallFunction::begin(generator, ctx, &name).arg(ndims).arg(shape).returning_void(); +} + +pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray_ndims: Instance<'ctx, Int>, + ndarray_shape: Instance<'ctx, Ptr>>, + output_ndims: Instance<'ctx, Int>, + output_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_output_shape_same", + ); + CallFunction::begin(generator, ctx, &name) + .arg(ndarray_ndims) + .arg(ndarray_shape) + .arg(output_ndims) + .arg(output_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("size") +} + +pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("nbytes") +} + +pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("len") +} + +pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous") +} + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + index: Instance<'ctx, Int>, +) -> Instance<'ctx, Ptr>> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Ptr>> { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index cbf51402..b90cb59b 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ - codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, + codegen::classes::{ListType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -24,7 +24,9 @@ use inkwell::{ AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; +use model::*; use nac3parser::ast::{Location, Stmt, StrRef}; +use object::ndarray::NDArray; use parking_lot::{Condvar, Mutex}; use std::collections::{HashMap, HashSet}; use std::sync::{ @@ -491,12 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, dtype, - ); - - NDArrayType::new(generator, ctx, element_type).as_base_type().into() + Ptr(Struct(NDArray)).get_type(generator, ctx).as_basic_type_enum() } _ => unreachable!( diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs index 93d2f453..f9abd433 100644 --- a/nac3core/src/codegen/object/mod.rs +++ b/nac3core/src/codegen/object/mod.rs @@ -1 +1,2 @@ pub mod any; +pub mod ndarray; diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs new file mode 100644 index 00000000..9bd0f65c --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -0,0 +1,349 @@ +use inkwell::{context::Context, types::BasicType, values::PointerValue, AddressSpace}; + +use crate::{ + codegen::{ + irrt::{ + call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement, + call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous, + call_nac3_ndarray_len, call_nac3_ndarray_nbytes, + call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + }, + model::*, + CodeGenContext, CodeGenerator, + }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, + typecheck::typedef::Type, +}; + +use super::any::AnyObject; + +/// Fields of [`NDArray`] +pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { + pub data: F::Out>>, + pub itemsize: F::Out>, + pub ndims: F::Out>, + pub shape: F::Out>>, + pub strides: F::Out>>, +} + +/// A strided ndarray in NAC3. +/// +/// See IRRT implementation for details about its fields. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDArray; + +impl<'ctx> StructKind<'ctx> for NDArray { + type Fields> = NDArrayFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + data: traversal.add_auto("data"), + itemsize: traversal.add_auto("itemsize"), + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + } + } +} + +/// A NAC3 Python ndarray object. +#[derive(Debug, Clone, Copy)] +pub struct NDArrayObject<'ctx> { + pub dtype: Type, + pub ndims: u64, + pub instance: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap(); + NDArrayObject { dtype, ndims, instance: value } + } + + /// Get this ndarray's `ndims` as an LLVM constant. + pub fn ndims_llvm( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Int> { + Int(SizeT).const_int(generator, ctx, self.ndims) + } + + /// Allocate an ndarray on the stack given its `ndims` and `dtype`. + /// + /// `shape` and `strides` will be automatically allocated on the stack. + //e + /// The returned ndarray's content will be: + /// - `data`: set to `nullptr`. + /// - `itemsize`: set to the `sizeof()` of `dtype`. + /// - `ndims`: set to the value of `ndims`. + /// - `shape`: allocated with an array of length `ndims` with uninitialized values. + /// - `strides`: allocated with an array of length `ndims` with uninitialized values. + pub fn alloca( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + ) -> Self { + let ndarray = Struct(NDArray).alloca(generator, ctx); + + let data = Ptr(Int(Byte)).nullptr(generator, ctx.ctx); + ndarray.set(ctx, |f| f.data, data); + + let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap(); + let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize); + ndarray.set(ctx, |f| f.itemsize, itemsize); + + let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims); + ndarray.set(ctx, |f| f.ndims, ndims_val); + + let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.shape, shape); + + let strides = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.strides, strides); + + NDArrayObject { dtype, ndims, instance: ndarray } + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_constant_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[u64], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim); + dst_shape.offset_const(ctx, i as u64).store(ctx, dim); + } + + ndarray + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_dynamic_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[Instance<'ctx, Int>], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + dst_shape.offset_const(ctx, i as u64).store(ctx, *dim); + } + + ndarray + } + + /// Initialize an ndarray's `data` by allocating a buffer on the stack. + /// The allocated data buffer is considered to be *owned* by the ndarray. + /// + /// `strides` of the ndarray will also be updated with `set_strides_by_shape`. + /// + /// `shape` and `itemsize` of the ndarray ***must*** be initialized first. + pub fn create_data( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + let nbytes = self.nbytes(generator, ctx); + + let data = Int(Byte).array_alloca(generator, ctx, nbytes.value); + self.instance.set(ctx, |f| f.data, data); + + self.set_strides_contiguous(generator, ctx); + } + + /// Copy shape dimensions from an array. + pub fn copy_shape_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance.get(generator, ctx, |f| f.shape).copy_from(generator, ctx, shape, num_items); + } + + /// Copy shape dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_shape_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape); + self.copy_shape_from_array(generator, ctx, src_shape); + } + + /// Copy strides dimensions from an array. + pub fn copy_strides_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + strides: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance + .get(generator, ctx, |f| f.strides) + .copy_from(generator, ctx, strides, num_items); + } + + /// Copy strides dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_strides_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides); + self.copy_strides_from_array(generator, ctx, src_strides); + } + + /// Get the `np.size()` of this ndarray. + pub fn size( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_size(generator, ctx, self.instance) + } + + /// Get the `ndarray.nbytes` of this ndarray. + pub fn nbytes( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_nbytes(generator, ctx, self.instance) + } + + /// Get the `len()` of this ndarray. + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_len(generator, ctx, self.instance) + } + + /// Check if this ndarray is C-contiguous. + /// + /// See NumPy's `flags["C_CONTIGUOUS"]`: + pub fn is_c_contiguous( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance) + } + + /// Get the pointer to the n-th (0-based) element. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_nth_pelement( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the n-th (0-based) scalar. + pub fn get_nth_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> AnyObject<'ctx> { + let ptr = self.get_nth_pelement(generator, ctx, nth); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Get the pointer to the element indexed by `indices`. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_pelement_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_pelement_by_indices(generator, ctx, self.instance, indices); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the scalar indexed by `indices`. + pub fn get_scalar_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> AnyObject<'ctx> { + let ptr = self.get_pelement_by_indices(generator, ctx, indices); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. + /// + /// Update the ndarray's strides to make the ndarray contiguous. + pub fn set_strides_contiguous( + self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance); + } + + /// Copy data from another ndarray. + /// + /// This ndarray and `src` is that their `np.size()` should be the same. Their shapes + /// do not matter. The copying order is determined by how their flattened views look. + /// + /// Panics if the `dtype`s of ndarrays are different. + pub fn copy_data_from( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src: NDArrayObject<'ctx>, + ) { + assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); + call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); + } +}