diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 404681cc..58d18f8a 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -4,3 +4,6 @@ #include "irrt/math.hpp" #include "irrt/ndarray.hpp" #include "irrt/slice.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" +#include "irrt/ndarray/iter.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/exception.hpp b/nac3core/irrt/irrt/exception.hpp index 5b1ec590..4a4b093b 100644 --- a/nac3core/irrt/irrt/exception.hpp +++ b/nac3core/irrt/irrt/exception.hpp @@ -55,11 +55,14 @@ void _raise_exception_helper(ExceptionId id, int64_t param2) { Exception e = { .id = id, - .filename = {.base = reinterpret_cast(filename), .len = __builtin_strlen(filename)}, + .filename = {.base = reinterpret_cast(const_cast(filename)), + .len = static_cast(__builtin_strlen(filename))}, .line = line, .column = 0, - .function = {.base = reinterpret_cast(function), .len = __builtin_strlen(function)}, - .msg = {.base = reinterpret_cast(msg), .len = __builtin_strlen(msg)}, + .function = {.base = reinterpret_cast(const_cast(function)), + .len = static_cast(__builtin_strlen(function))}, + .msg = {.base = reinterpret_cast(const_cast(msg)), + .len = static_cast(__builtin_strlen(msg))}, }; e.params[0] = param0; e.params[1] = param1; diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 5dda1739..cacdd2a2 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -2,6 +2,8 @@ #include "irrt/int_types.hpp" +// TODO: To be deleted since NDArray with strides is done. + namespace { template SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp new file mode 100644 index 00000000..2582acf8 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -0,0 +1,341 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray { +namespace basic { +/** + * @brief Assert that `shape` does not contain negative dimensions. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape to check on + */ +template +void assert_shape_no_negative(SizeT ndims, const SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + if (shape[axis] < 0) { + raise_exception(SizeT, EXN_VALUE_ERROR, + "negative dimensions are not allowed; axis {0} " + "has dimension {1}", + axis, shape[axis], NO_PARAM); + } + } +} + +/** + * @brief Assert that two shapes are the same in the context of writing output to an ndarray. + */ +template +void assert_output_shape_same(SizeT ndarray_ndims, + const SizeT* ndarray_shape, + SizeT output_ndims, + const SizeT* output_shape) { + if (ndarray_ndims != output_ndims) { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}", + output_ndims, ndarray_ndims, NO_PARAM); + } + + for (SizeT axis = 0; axis < ndarray_ndims; axis++) { + if (ndarray_shape[axis] != output_shape[axis]) { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, + "Mismatched dimensions on axis {0}, output has " + "dimension {1}, but destination ndarray has dimension {2}.", + axis, output_shape[axis], ndarray_shape[axis]); + } + } +} + +/** + * @brief Return the number of elements of an ndarray given its shape. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape of the ndarray + */ +template +SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) { + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) + size *= shape[axis]; + return size; +} + +/** + * @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape. + * + * @param ndims Number of elements in `shape` and `indices` + * @param shape The shape of the ndarray + * @param indices The returned indices indexing the ndarray with shape `shape`. + * @param nth The index of the element of interest. + */ +template +void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) { + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = ndims - i - 1; + SizeT dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} + +/** + * @brief Return the number of elements of an `ndarray` + * + * This function corresponds to `.size` + */ +template +SizeT size(const NDArray* ndarray) { + return calc_size_from_shape(ndarray->ndims, ndarray->shape); +} + +/** + * @brief Return of the number of its content of an `ndarray`. + * + * This function corresponds to `.nbytes`. + */ +template +SizeT nbytes(const NDArray* ndarray) { + return size(ndarray) * ndarray->itemsize; +} + +/** + * @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object. + * + * This function corresponds to `.__len__`. + * + * @param dst_length The length. + */ +template +SizeT len(const NDArray* ndarray) { + // numpy prohibits `__len__` on unsized objects + if (ndarray->ndims == 0) { + raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM); + } else { + return ndarray->shape[0]; + } +} + +/** + * @brief Return a boolean indicating if `ndarray` is (C-)contiguous. + * + * You may want to see ndarray's rules for C-contiguity: + * https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + */ +template +bool is_c_contiguous(const NDArray* ndarray) { + // References: + // - tinynumpy's implementation: + // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 + // - ndarray's flags["C_CONTIGUOUS"]: + // https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags + // - ndarray's rules for C-contiguity: + // https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + + // From + // https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: + // + // The traditional rule is that for an array to be flagged as C contiguous, + // the following must hold: + // + // strides[-1] == itemsize + // strides[i] == shape[i+1] * strides[i + 1] + // [...] + // According to these rules, a 0- or 1-dimensional array is either both + // C- and F-contiguous, or neither; and an array with 2+ dimensions + // can be C- or F- contiguous, or neither, but not both. Though there + // there are exceptions for arrays with zero or one item, in the first + // case the check is relaxed up to and including the first dimension + // with shape[i] == 0. In the second case `strides == itemsize` will + // can be true for all dimensions and both flags are set. + + if (ndarray->ndims == 0) { + return true; + } + + if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) { + return false; + } + + for (SizeT i = 1; i < ndarray->ndims; i++) { + SizeT axis_i = ndarray->ndims - i - 1; + if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) { + return false; + } + } + + return true; +} + +/** + * @brief Return the pointer to the element indexed by `indices` along the ndarray's axes. + * + * This function does no bound check. + */ +template +uint8_t* get_pelement_by_indices(const NDArray* ndarray, const SizeT* indices) { + uint8_t* element = ndarray->data; + for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) + element += indices[dim_i] * ndarray->strides[dim_i]; + return element; +} + +/** + * @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view. + * + * This function does no bound check. + */ +template +uint8_t* get_nth_pelement(const NDArray* ndarray, SizeT nth) { + uint8_t* element = ndarray->data; + for (SizeT i = 0; i < ndarray->ndims; i++) { + SizeT axis = ndarray->ndims - i - 1; + SizeT dim = ndarray->shape[axis]; + element += ndarray->strides[axis] * (nth % dim); + nth /= dim; + } + return element; +} + +/** + * @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous. + * + * You might want to read https://ajcr.net/stride-guide-part-1/. + */ +template +void set_strides_by_shape(NDArray* ndarray) { + SizeT stride_product = 1; + for (SizeT i = 0; i < ndarray->ndims; i++) { + SizeT axis = ndarray->ndims - i - 1; + ndarray->strides[axis] = stride_product * ndarray->itemsize; + stride_product *= ndarray->shape[axis]; + } +} + +/** + * @brief Set an element in `ndarray`. + * + * @param pelement Pointer to the element in `ndarray` to be set. + * @param pvalue Pointer to the value `pelement` will be set to. + */ +template +void set_pelement_value(NDArray* ndarray, uint8_t* pelement, const uint8_t* pvalue) { + __builtin_memcpy(pelement, pvalue, ndarray->itemsize); +} + +/** + * @brief Copy data from one ndarray to another of the exact same size and itemsize. + * + * Both ndarrays will be viewed in their flatten views when copying the elements. + */ +template +void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { + // TODO: Make this faster with memcpy when we see a contiguous segment. + // TODO: Handle overlapping. + + debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize); + + for (SizeT i = 0; i < size(src_ndarray); i++) { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); + } +} +} // namespace basic +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::basic; + +void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) { + assert_shape_no_negative(ndims, shape); +} + +void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) { + assert_shape_no_negative(ndims, shape); +} + +void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims, + const int32_t* ndarray_shape, + int32_t output_ndims, + const int32_t* output_shape) { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); +} + +void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims, + const int64_t* ndarray_shape, + int64_t output_ndims, + const int64_t* output_shape) { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); +} + +uint32_t __nac3_ndarray_size(NDArray* ndarray) { + return size(ndarray); +} + +uint64_t __nac3_ndarray_size64(NDArray* ndarray) { + return size(ndarray); +} + +uint32_t __nac3_ndarray_nbytes(NDArray* ndarray) { + return nbytes(ndarray); +} + +uint64_t __nac3_ndarray_nbytes64(NDArray* ndarray) { + return nbytes(ndarray); +} + +int32_t __nac3_ndarray_len(NDArray* ndarray) { + return len(ndarray); +} + +int64_t __nac3_ndarray_len64(NDArray* ndarray) { + return len(ndarray); +} + +bool __nac3_ndarray_is_c_contiguous(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +bool __nac3_ndarray_is_c_contiguous64(NDArray* ndarray) { + return is_c_contiguous(ndarray); +} + +uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray* ndarray, int32_t nth) { + return get_nth_pelement(ndarray, nth); +} + +uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray* ndarray, int64_t nth) { + return get_nth_pelement(ndarray, nth); +} + +uint8_t* __nac3_ndarray_get_pelement_by_indices(const NDArray* ndarray, int32_t* indices) { + return get_pelement_by_indices(ndarray, indices); +} + +uint8_t* __nac3_ndarray_get_pelement_by_indices64(const NDArray* ndarray, int64_t* indices) { + return get_pelement_by_indices(ndarray, indices); +} + +void __nac3_ndarray_set_strides_by_shape(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_set_strides_by_shape64(NDArray* ndarray) { + set_strides_by_shape(ndarray); +} + +void __nac3_ndarray_copy_data(NDArray* src_ndarray, NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} + +void __nac3_ndarray_copy_data64(NDArray* src_ndarray, NDArray* dst_ndarray) { + copy_data(src_ndarray, dst_ndarray); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/def.hpp b/nac3core/irrt/irrt/ndarray/def.hpp new file mode 100644 index 00000000..8ec1cf63 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/def.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "irrt/int_types.hpp" + +namespace { +/** + * @brief The NDArray object + * + * Official numpy implementation: + * https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst + */ +template +struct NDArray { + /** + * @brief The underlying data this `ndarray` is pointing to. + */ + uint8_t* data; + + /** + * @brief The number of bytes of a single element in `data`. + */ + SizeT itemsize; + + /** + * @brief The number of dimensions of this shape. + */ + SizeT ndims; + + /** + * @brief The NDArray shape, with length equal to `ndims`. + * + * Note that it may contain 0. + */ + SizeT* shape; + + /** + * @brief Array strides, with length equal to `ndims` + * + * The stride values are in units of bytes, not number of elements. + * + * Note that `strides` can have negative values or contain 0. + */ + SizeT* strides; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/iter.hpp b/nac3core/irrt/irrt/ndarray/iter.hpp new file mode 100644 index 00000000..d419b7e8 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/iter.hpp @@ -0,0 +1,146 @@ +#pragma once + +#include "irrt/int_types.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +/** + * @brief Helper struct to enumerate through an ndarray *efficiently*. + * + * Example usage (in pseudo-code): + * ``` + * // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double` + * NDIter nditer; + * nditer.initialize(my_ndarray); + * while (nditer.has_element()) { + * // This body is run 6 (= my_ndarray.size) times. + * + * // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end + * print(nditer.indices); + * + * // 0 -> 1 -> 2 -> 3 -> 4 -> 5 + * print(nditer.nth); + * + * // <1st element> -> <2nd element> -> ... -> <6th element> -> end + * print(*((double *) nditer.element)) + * + * nditer.next(); // Go to next element. + * } + * ``` + * + * Interesting cases: + * - If `my_ndarray.ndims` == 0, there is one iteration. + * - If `my_ndarray.shape` contains zeroes, there are no iterations. + */ +template +struct NDIter { + // Information about the ndarray being iterated over. + SizeT ndims; + SizeT* shape; + SizeT* strides; + + /** + * @brief The current indices. + * + * Must be allocated by the caller. + */ + SizeT* indices; + + /** + * @brief The nth (0-based) index of the current indices. + * + * Initially this is 0. + */ + SizeT nth; + + /** + * @brief Pointer to the current element. + * + * Initially this points to first element of the ndarray. + */ + uint8_t* element; + + /** + * @brief Cache for the product of shape. + * + * Could be 0 if `shape` has 0s in it. + */ + SizeT size; + + void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element, SizeT* indices) { + this->ndims = ndims; + this->shape = shape; + this->strides = strides; + + this->indices = indices; + this->element = element; + + // Compute size + this->size = 1; + for (SizeT i = 0; i < ndims; i++) { + this->size *= shape[i]; + } + + // `indices` starts on all 0s. + for (SizeT axis = 0; axis < ndims; axis++) + indices[axis] = 0; + nth = 0; + } + + void initialize_by_ndarray(NDArray* ndarray, SizeT* indices) { + // NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first + // element as well. + this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices); + } + + // Is the current iteration valid? + // If true, then `element`, `indices` and `nth` contain details about the current element. + bool has_element() { return nth < size; } + + // Go to the next element. + void next() { + for (SizeT i = 0; i < ndims; i++) { + SizeT axis = ndims - i - 1; + indices[axis]++; + if (indices[axis] >= shape[axis]) { + indices[axis] = 0; + + // TODO: There is something called backstrides to speedup iteration. + // See https://ajcr.net/stride-guide-part-1/, and + // https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides. + element -= strides[axis] * (shape[axis] - 1); + } else { + element += strides[axis]; + break; + } + } + nth++; + } +}; +} // namespace + +extern "C" { +void __nac3_nditer_initialize(NDIter* iter, NDArray* ndarray, int32_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +void __nac3_nditer_initialize64(NDIter* iter, NDArray* ndarray, int64_t* indices) { + iter->initialize_by_ndarray(ndarray, indices); +} + +bool __nac3_nditer_has_element(NDIter* iter) { + return iter->has_element(); +} + +bool __nac3_nditer_has_element64(NDIter* iter) { + return iter->has_element(); +} + +void __nac3_nditer_next(NDIter* iter) { + iter->next(); +} + +void __nac3_nditer_next64(NDIter* iter) { + iter->next(); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4260ebef..d2e02bc1 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -7,16 +7,17 @@ use itertools::Itertools; use super::{ classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, expr::destructure_range, extern_fns, irrt, irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, + model::*, numpy, numpy::ndarray_elementwise_unaryop_impl, + object::{any::AnyObject, list::ListObject, ndarray::NDArrayObject, tuple::TupleObject}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -42,58 +43,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let range_ty = ctx.primitives.range; let (arg_ty, arg) = n; - - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) { let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + let arg = AnyObject { ty: arg_ty, value: arg }; + let len: Instance<'ctx, Int> = match &*ctx.unifier.get_ty(arg_ty) { + TypeEnum::TTuple { .. } => { + let tuple = TupleObject::from_object(ctx, arg); + tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - _ => codegen_unreachable!(ctx), - } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, arg); + list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) + } + _ => unsupported_type(ctx, "len", &[arg_ty]), + }; + len.value }) } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 15af3cdb..7f2cd687 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -18,6 +18,8 @@ use super::{ }, llvm_intrinsics, macros::codegen_unreachable, + model::{function::FnCall, *}, + object::ndarray::{nditer::NDIter, NDArray}, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -950,3 +952,163 @@ pub fn call_ndarray_calc_broadcast_index< Box::new(|_, v| v.into()), ) } + +// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". +// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". +#[must_use] +pub fn get_sizet_dependent_function_name( + generator: &mut G, + ctx: &CodeGenContext<'_, '_>, + name: &str, +) -> String { + let mut name = name.to_owned(); + match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => {} + 64 => name.push_str("64"), + bit_width => { + panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") + } + } + name +} + +pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndims: Instance<'ctx, Int>, + shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_shape_no_negative", + ); + FnCall::builder(generator, ctx, &name).arg(ndims).arg(shape).returning_void(); +} + +pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray_ndims: Instance<'ctx, Int>, + ndarray_shape: Instance<'ctx, Ptr>>, + output_ndims: Instance<'ctx, Int>, + output_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_output_shape_same", + ); + FnCall::builder(generator, ctx, &name) + .arg(ndarray_ndims) + .arg(ndarray_shape) + .arg(output_ndims) + .arg(output_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("size") +} + +pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("nbytes") +} + +pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("len") +} + +pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + FnCall::builder(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous") +} + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + index: Instance<'ctx, Int>, +) -> Instance<'ctx, Ptr>> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + FnCall::builder(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Ptr>> { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + FnCall::builder(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + FnCall::builder(generator, ctx, &name).arg(ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + FnCall::builder(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + FnCall::builder(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void(); +} + +pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + FnCall::builder(generator, ctx, &name).arg(iter).returning_auto("has_element") +} + +pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); + FnCall::builder(generator, ctx, &name).arg(iter).returning_void(); +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 7d83cb86..cc0c5d0a 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -30,15 +30,17 @@ use nac3parser::ast::{Location, Stmt, StrRef}; use crate::{ symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, }, }; -use classes::{ListType, NDArrayType, ProxyType, RangeType}; +use classes::{ListType, ProxyType, RangeType}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; +use model::*; +use object::ndarray::NDArray; pub mod builtin_fns; pub mod classes; @@ -50,6 +52,7 @@ pub mod irrt; pub mod llvm_intrinsics; pub mod model; pub mod numpy; +pub mod object; pub mod stmt; #[cfg(test)] @@ -510,12 +513,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, dtype, - ); - - NDArrayType::new(generator, ctx, element_type).as_base_type().into() + Ptr(Struct(NDArray)).llvm_type(generator, ctx).as_basic_type_enum() } _ => unreachable!( diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs index 22bb3332..4256e84e 100644 --- a/nac3core/src/codegen/model/mod.rs +++ b/nac3core/src/codegen/model/mod.rs @@ -6,6 +6,7 @@ pub mod function; mod int; mod ptr; mod structure; +pub mod util; pub use any::*; pub use array::*; diff --git a/nac3core/src/codegen/model/util.rs b/nac3core/src/codegen/model/util.rs new file mode 100644 index 00000000..cbff14c1 --- /dev/null +++ b/nac3core/src/codegen/model/util.rs @@ -0,0 +1,41 @@ +use super::*; +use crate::codegen::{ + stmt::{gen_for_callback_incrementing, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. +/// +/// The value for `stop` is exclusive. +pub fn gen_for_model<'ctx, 'a, G, F, N>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + start: Instance<'ctx, Int>, + stop: Instance<'ctx, Int>, + step: Instance<'ctx, Int>, + body: F, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + Instance<'ctx, Int>, + ) -> Result<(), String>, + N: IntKind<'ctx> + Default, +{ + let int_model = Int(N::default()); + gen_for_callback_incrementing( + generator, + ctx, + None, + start.value, + (stop.value, false), + |g, ctx, hooks, i| { + let i = unsafe { int_model.believe_value(i) }; + body(g, ctx, hooks, i) + }, + step.value, + ) +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ffe0d83d..68e5149f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -19,13 +19,17 @@ use super::{ }, llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, + object::{ + any::AnyObject, + ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, + }, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, @@ -1742,8 +1746,13 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_empty(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1760,8 +1769,13 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_zeros(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.ones`. @@ -1778,8 +1792,13 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_ones(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.full`. @@ -1799,8 +1818,14 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = + NDArrayObject::make_np_full(generator, context, dtype, ndims, shape, fill_value_arg); + Ok(ndarray.instance.value) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/object/any.rs b/nac3core/src/codegen/object/any.rs new file mode 100644 index 00000000..c7a983e0 --- /dev/null +++ b/nac3core/src/codegen/object/any.rs @@ -0,0 +1,12 @@ +use inkwell::values::BasicValueEnum; + +use crate::typecheck::typedef::Type; + +/// A NAC3 LLVM Python object of any type. +#[derive(Debug, Clone, Copy)] +pub struct AnyObject<'ctx> { + /// Typechecker type of the object. + pub ty: Type, + /// LLVM value of the object. + pub value: BasicValueEnum<'ctx>, +} diff --git a/nac3core/src/codegen/object/list.rs b/nac3core/src/codegen/object/list.rs new file mode 100644 index 00000000..36a864d5 --- /dev/null +++ b/nac3core/src/codegen/object/list.rs @@ -0,0 +1,75 @@ +use super::any::AnyObject; +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, +}; + +/// Fields of [`List`] +pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { + /// Array pointer to content + pub items: F::Output>, + /// Number of items in the array + pub len: F::Output>, +} + +/// A list in NAC3. +#[derive(Debug, Clone, Copy, Default)] +pub struct List { + /// Model of the list items + pub item: Item, +} + +impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List { + type Fields> = ListFields<'ctx, F, Item>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + items: traversal.add("items", Ptr(self.item)), + len: traversal.add_auto("len"), + } + } +} + +/// A NAC3 Python List object. +#[derive(Debug, Clone, Copy)] +pub struct ListObject<'ctx> { + /// Typechecker type of the list items + pub item_type: Type, + pub instance: Instance<'ctx, Ptr>>>>, +} + +impl<'ctx> ListObject<'ctx> { + /// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> Self { + // Check typechecker type and extract `item_type` + let item_type = match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty // Extract `item_type` + } + _ => { + panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty)) + } + }; + + let plist = Ptr(Struct(List { item: Any(ctx.get_llvm_type(generator, item_type)) })); + + // Create object + let value = plist.check_value(generator, ctx.ctx, object.value).unwrap(); + ListObject { item_type, instance: value } + } + + /// Get the `len()` of this list. + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + self.instance.get(generator, ctx, |f| f.len) + } +} diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs new file mode 100644 index 00000000..9466d70b --- /dev/null +++ b/nac3core/src/codegen/object/mod.rs @@ -0,0 +1,4 @@ +pub mod any; +pub mod list; +pub mod ndarray; +pub mod tuple; diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs new file mode 100644 index 00000000..712c0cff --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -0,0 +1,125 @@ +use inkwell::values::BasicValueEnum; + +use super::NDArrayObject; +use crate::{ + codegen::{ + irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext, + CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// Get the zero value in `np.zeros()` of a `dtype`. +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create an ndarray like `np.empty`. + pub fn make_np_empty( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + // Validate `shape` + let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims, false); + call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.create_data(generator, ctx); + + ndarray + } + + /// Create an ndarray like `np.full`. + pub fn make_np_full( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + fill_value: BasicValueEnum<'ctx>, + ) -> Self { + let ndarray = NDArrayObject::make_np_empty(generator, ctx, dtype, ndims, shape); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like `np.zero`. + pub fn make_np_zeros( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_zero_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } + + /// Create an ndarray like `np.ones`. + pub fn make_np_ones( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_one_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs new file mode 100644 index 00000000..8b55c2e1 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -0,0 +1,371 @@ +use inkwell::{ + context::Context, + types::BasicType, + values::{BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use super::any::AnyObject; +use crate::{ + codegen::{ + irrt::{ + call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement, + call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous, + call_nac3_ndarray_len, call_nac3_ndarray_nbytes, + call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + }, + model::*, + CodeGenContext, CodeGenerator, + }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, + typecheck::typedef::Type, +}; + +pub mod factory; +pub mod nditer; +pub mod shape_util; + +/// Fields of [`NDArray`] +pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { + pub data: F::Output>>, + pub itemsize: F::Output>, + pub ndims: F::Output>, + pub shape: F::Output>>, + pub strides: F::Output>>, +} + +/// A strided ndarray in NAC3. +/// +/// See IRRT implementation for details about its fields. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDArray; + +impl<'ctx> StructKind<'ctx> for NDArray { + type Fields> = NDArrayFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + data: traversal.add_auto("data"), + itemsize: traversal.add_auto("itemsize"), + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + } + } +} + +/// A NAC3 Python ndarray object. +#[derive(Debug, Clone, Copy)] +pub struct NDArrayObject<'ctx> { + pub dtype: Type, + pub ndims: u64, + pub instance: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap(); + NDArrayObject { dtype, ndims, instance: value } + } + + /// Get this ndarray's `ndims` as an LLVM constant. + pub fn ndims_llvm( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Int> { + Int(SizeT).const_int(generator, ctx, self.ndims, false) + } + + /// Allocate an ndarray on the stack given its `ndims` and `dtype`. + /// + /// `shape` and `strides` will be automatically allocated onto the stack. + /// + /// The returned ndarray's content will be: + /// - `data`: uninitialized. + /// - `itemsize`: set to the `sizeof()` of `dtype`. + /// - `ndims`: set to the value of `ndims`. + /// - `shape`: allocated with an array of length `ndims` with uninitialized values. + /// - `strides`: allocated with an array of length `ndims` with uninitialized values. + pub fn alloca( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + ) -> Self { + let ndarray = Struct(NDArray).alloca(generator, ctx); + + let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap(); + let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize); + ndarray.set(ctx, |f| f.itemsize, itemsize); + + let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims, false); + ndarray.set(ctx, |f| f.ndims, ndims_val); + + let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.shape, shape); + + let strides = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.strides, strides); + + NDArrayObject { dtype, ndims, instance: ndarray } + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_constant_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[u64], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim, false); + dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, dim); + } + + ndarray + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_dynamic_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[Instance<'ctx, Int>], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + dst_shape.offset_const(ctx, i64::try_from(i).unwrap()).store(ctx, *dim); + } + + ndarray + } + + /// Initialize an ndarray's `data` by allocating a buffer on the stack. + /// The allocated data buffer is considered to be *owned* by the ndarray. + /// + /// `strides` of the ndarray will also be updated with `set_strides_by_shape`. + /// + /// `shape` and `itemsize` of the ndarray ***must*** be initialized first. + pub fn create_data( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + let nbytes = self.nbytes(generator, ctx); + + let data = Int(Byte).array_alloca(generator, ctx, nbytes.value); + self.instance.set(ctx, |f| f.data, data); + + self.set_strides_contiguous(generator, ctx); + } + + /// Copy shape dimensions from an array. + pub fn copy_shape_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance.get(generator, ctx, |f| f.shape).copy_from(generator, ctx, shape, num_items); + } + + /// Copy shape dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_shape_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape); + self.copy_shape_from_array(generator, ctx, src_shape); + } + + /// Copy strides dimensions from an array. + pub fn copy_strides_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + strides: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance + .get(generator, ctx, |f| f.strides) + .copy_from(generator, ctx, strides, num_items); + } + + /// Copy strides dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_strides_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides); + self.copy_strides_from_array(generator, ctx, src_strides); + } + + /// Get the `np.size()` of this ndarray. + pub fn size( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_size(generator, ctx, self.instance) + } + + /// Get the `ndarray.nbytes` of this ndarray. + pub fn nbytes( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_nbytes(generator, ctx, self.instance) + } + + /// Get the `len()` of this ndarray. + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_len(generator, ctx, self.instance) + } + + /// Check if this ndarray is C-contiguous. + /// + /// See NumPy's `flags["C_CONTIGUOUS"]`: + pub fn is_c_contiguous( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance) + } + + /// Get the pointer to the n-th (0-based) element. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_nth_pelement( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the n-th (0-based) scalar. + pub fn get_nth_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> AnyObject<'ctx> { + let ptr = self.get_nth_pelement(generator, ctx, nth); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Get the pointer to the element indexed by `indices`. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_pelement_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_pelement_by_indices(generator, ctx, self.instance, indices); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the scalar indexed by `indices`. + pub fn get_scalar_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> AnyObject<'ctx> { + let ptr = self.get_pelement_by_indices(generator, ctx, indices); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. + /// + /// Update the ndarray's strides to make the ndarray contiguous. + pub fn set_strides_contiguous( + self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance); + } + + /// Copy data from another ndarray. + /// + /// This ndarray and `src` is that their `np.size()` should be the same. Their shapes + /// do not matter. The copying order is determined by how their flattened views look. + /// + /// Panics if the `dtype`s of ndarrays are different. + pub fn copy_data_from( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src: NDArrayObject<'ctx>, + ) { + assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); + call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); + } + + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } +} diff --git a/nac3core/src/codegen/object/ndarray/nditer.rs b/nac3core/src/codegen/object/ndarray/nditer.rs new file mode 100644 index 00000000..b0d29b7a --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/nditer.rs @@ -0,0 +1,178 @@ +use inkwell::{types::BasicType, values::PointerValue, AddressSpace}; + +use super::NDArrayObject; +use crate::codegen::{ + irrt::{call_nac3_nditer_has_element, call_nac3_nditer_initialize, call_nac3_nditer_next}, + model::*, + object::any::AnyObject, + stmt::{gen_for_callback, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +/// Fields of [`NDIter`] +pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Output>, + pub shape: F::Output>>, + pub strides: F::Output>>, + + pub indices: F::Output>>, + pub nth: F::Output>, + pub element: F::Output>>, + + pub size: F::Output>, +} + +/// An IRRT helper structure used to iterate through an ndarray. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDIter; + +impl<'ctx> StructKind<'ctx> for NDIter { + type Fields> = NDIterFields<'ctx, F>; + + fn iter_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + + indices: traversal.add_auto("indices"), + nth: traversal.add_auto("nth"), + element: traversal.add_auto("element"), + + size: traversal.add_auto("size"), + } + } +} + +/// A helper structure with a convenient interface to interact with [`NDIter`]. +#[derive(Debug, Clone)] +pub struct NDIterHandle<'ctx> { + instance: Instance<'ctx, Ptr>>, + /// The ndarray this [`NDIter`] to iterating over. + ndarray: NDArrayObject<'ctx>, + /// The current indices of [`NDIter`]. + indices: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDIterHandle<'ctx> { + /// Allocate an [`NDIter`] that iterates through an ndarray. + pub fn new( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + ) -> Self { + let nditer = Struct(NDIter).alloca(generator, ctx); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); + + // The caller has the responsibility to allocate 'indices' for `NDIter`. + let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value); + call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices); + + NDIterHandle { ndarray, instance: nditer, indices } + } + + /// Is the current iteration valid? + /// + /// If true, then `element`, `indices` and `nth` contain details about the current element. + /// + /// If `ndarray` is unsized, this returns true only for the first iteration. + /// If `ndarray` is 0-sized, this always returns false. + #[must_use] + pub fn has_element( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_nditer_has_element(generator, ctx, self.instance) + } + + /// Go to the next element. If `has_element()` is false, then this has undefined behavior. + /// + /// If `ndarray` is unsized, this can only be called once. + /// If `ndarray` is 0-sized, this can never be called. + pub fn next( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_nditer_next(generator, ctx, self.instance); + } + + /// Get pointer to the current element. + #[must_use] + pub fn get_pointer( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype); + + let p = self.instance.get(generator, ctx, |f| f.element); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element") + .unwrap() + } + + /// Get the value of the current element. + #[must_use] + pub fn get_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + let p = self.get_pointer(generator, ctx); + let value = ctx.builder.build_load(p, "value").unwrap(); + AnyObject { ty: self.ndarray.dtype, value } + } + + /// Get the index of the current element if this ndarray were a flat ndarray. + #[must_use] + pub fn get_index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + self.instance.get(generator, ctx, |f| f.nth) + } + + /// Get the indices of the current element. + #[must_use] + pub fn get_indices(&self) -> Instance<'ctx, Ptr>> { + self.indices + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Iterate through every element in the ndarray. + /// + /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to + /// get properties of the current iteration (e.g., the current element, indices, etc.) + pub fn foreach<'a, G, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + body: F, + ) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + NDIterHandle<'ctx>, + ) -> Result<(), String>, + { + gen_for_callback( + generator, + ctx, + Some("ndarray_foreach"), + |generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)), + |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx).value), + |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), + |generator, ctx, nditer| { + nditer.next(generator, ctx); + Ok(()) + }, + ) + } +} diff --git a/nac3core/src/codegen/object/ndarray/shape_util.rs b/nac3core/src/codegen/object/ndarray/shape_util.rs new file mode 100644 index 00000000..5936b9e3 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/shape_util.rs @@ -0,0 +1,104 @@ +use crate::{ + codegen::{ + model::*, + object::{any::AnyObject, list::ListObject, tuple::TupleObject}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::TypeEnum, +}; +use util::gen_for_model; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + input_sequence: AnyObject<'ctx>, +) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let one = Int(SizeT).const_1(generator, ctx.ctx); + + // The result `list` to return. + match &*ctx.unifier.get_ty(input_sequence.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + // Check `input_sequence` + let input_sequence = ListObject::from_object(generator, ctx, input_sequence); + + let len = input_sequence.instance.get(generator, ctx, |f| f.len); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_model(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| { + // Load the i-th int32 in the input sequence + let int = input_sequence + .instance + .get(generator, ctx, |f| f.items) + .get_index(generator, ctx, i.value) + .value + .into_int_value(); + + // Cast to SizeT + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + // Store + result.set_index(ctx, i.value, int); + + Ok(()) + }) + .unwrap(); + + (len, result) + } + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_sequence = TupleObject::from_object(ctx, input_sequence); + + let len = input_sequence.len(generator, ctx); + + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + for i in 0..input_sequence.num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_sequence.index(ctx, i).value.into_int_value(); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + result.set_index_const(ctx, i64::try_from(i).unwrap(), int); + } + + (len, result) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + let input_int = input_sequence.value.into_int_value(); + + let len = Int(SizeT).const_1(generator, ctx.ctx); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, input_int); + + // Storing into result[0] + result.store(ctx, int); + + (len, result) + } + _ => panic!( + "encountered unknown sequence type: {}", + ctx.unifier.stringify(input_sequence.ty) + ), + } +} diff --git a/nac3core/src/codegen/object/tuple.rs b/nac3core/src/codegen/object/tuple.rs new file mode 100644 index 00000000..376f89a5 --- /dev/null +++ b/nac3core/src/codegen/object/tuple.rs @@ -0,0 +1,98 @@ +use inkwell::values::StructValue; +use itertools::Itertools; + +use super::any::AnyObject; +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// A NAC3 tuple object. +/// +/// NOTE: This struct has no copy trait. +#[derive(Debug, Clone)] +pub struct TupleObject<'ctx> { + /// The type of the tuple. + pub tys: Vec, + /// The underlying LLVM struct value of this tuple. + pub value: StructValue<'ctx>, +} + +impl<'ctx> TupleObject<'ctx> { + pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self { + // TODO: Keep `is_vararg_ctx` from TTuple? + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else { + panic!( + "Expected type to be a TypeEnum::TTuple, got {}", + ctx.unifier.stringify(object.ty) + ); + }; + + // Check number of fields + let value = object.value.into_struct_value(); + let value_num_fields = value.get_type().count_fields() as usize; + assert!( + value_num_fields == tys.len(), + "Tuple type has {} item(s), but the LLVM struct value has {} field(s)", + tys.len(), + value_num_fields + ); + + TupleObject { tys: tys.clone(), value } + } + + /// Convenience function. Create a [`TupleObject`] from an iterator of objects. + pub fn from_objects( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + objects: I, + ) -> Self + where + I: IntoIterator>, + { + let (values, tys): (Vec<_>, Vec<_>) = + objects.into_iter().map(|object| (object.value, object.ty)).unzip(); + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false); + + let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap(); + for (i, val) in values.into_iter().enumerate() { + let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + } + + let value = ctx.builder.build_load(pllvm_tuple, "").unwrap().into_struct_value(); + TupleObject { tys, value } + } + + #[must_use] + pub fn num_elements(&self) -> usize { + self.tys.len() + } + + /// Get the `len()` of this tuple. + #[must_use] + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64, false) + } + + /// Get the `i`-th (0-based) object in this tuple. + pub fn index(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize) -> AnyObject<'ctx> { + assert!( + i < self.num_elements(), + "Tuple object with length {} have index {i}", + self.num_elements() + ); + + let value = ctx.builder.build_extract_value(self.value, i as u32, "tuple[{i}]").unwrap(); + let ty = self.tys[i]; + AnyObject { ty, value } + } +} diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 96613107..71ee35b0 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1134,3 +1134,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { _ => 0, } } + +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +}