From b6980c3a397b9cae39223e51dac99fdb762a87e0 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:00:27 +0800 Subject: [PATCH 1/4] core/ndstrides: add NDArrayObject::make_copy --- nac3core/src/codegen/object/ndarray/mod.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 7fa3365a..3d0b26f4 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -337,6 +337,24 @@ impl<'ctx> NDArrayObject<'ctx> { call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance); } + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over. + /// + /// The new ndarray will own its data and will be C-contiguous. + #[must_use] + pub fn make_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims); + + let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx); + clone.copy_shape_from_array(generator, ctx, shape); + clone.create_data(generator, ctx); + clone.copy_data_from(generator, ctx, *self); + clone + } + /// Copy data from another ndarray. /// /// This ndarray and `src` is that their `np.size()` should be the same. Their shapes -- 2.44.2 From 9cfa2622cab1775602aab8da64896edd563d5523 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:02:42 +0800 Subject: [PATCH 2/4] core/ndstrides: add NDArrayObject::atleast_nd --- nac3core/src/codegen/object/ndarray/mod.rs | 1 + nac3core/src/codegen/object/ndarray/view.rs | 28 +++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 nac3core/src/codegen/object/ndarray/view.rs diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 3d0b26f4..a888ea70 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -25,6 +25,7 @@ pub mod factory; pub mod indexing; pub mod nditer; pub mod shape_util; +pub mod view; /// Fields of [`NDArray`] pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs new file mode 100644 index 00000000..8afd5b4c --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -0,0 +1,28 @@ +use super::{indexing::RustNDIndex, NDArrayObject}; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Make sure the ndarray is at least `ndmin`-dimensional. + /// + /// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape. + /// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray. + #[must_use] + pub fn atleast_nd( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndmin: u64, + ) -> Self { + if self.ndims < ndmin { + // Extend the dimensions with np.newaxis. + let mut indices = vec![]; + for _ in self.ndims..ndmin { + indices.push(RustNDIndex::NewAxis); + } + indices.push(RustNDIndex::Ellipsis); + self.index(generator, ctx, &indices) + } else { + *self + } + } +} -- 2.44.2 From b8190ccc874429028df362731c687bbda59e41a2 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 14:51:40 +0800 Subject: [PATCH 3/4] core/irrt: add List Needed for implementing np_array() --- nac3core/irrt/irrt.cpp | 2 +- nac3core/irrt/irrt/list.hpp | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index b586ce57..aa25667d 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -8,4 +8,4 @@ #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" -#include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/indexing.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index b389197f..3404fb8f 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -2,6 +2,21 @@ #include "irrt/int_types.hpp" #include "irrt/math_util.hpp" +#include "irrt/slice.hpp" + +namespace { +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. + */ +template +struct List { + uint8_t* items; + SizeT len; +}; +} // namespace extern "C" { // Handle list assignment and dropping part of the list when -- 2.44.2 From 8f0084ac8a76a6262f16cfbf96c61b6c440627ea Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:10:39 +0800 Subject: [PATCH 4/4] core/ndstrides: implement np_array() It also checks for inconsistent dimensions if the input is a list. e.g., rejecting `[[1.0, 2.0], [3.0]]`. However, currently only `np_array(, copy=False)` and `np_array(, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(, copy=False)` behaves like NumPy's `np.array(, copy=None)`. --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/array.hpp | 134 ++++++++++++++ nac3core/src/codegen/irrt/mod.rs | 34 +++- nac3core/src/codegen/model/ptr.rs | 9 + nac3core/src/codegen/numpy.rs | 57 ++---- nac3core/src/codegen/object/list.rs | 11 ++ nac3core/src/codegen/object/ndarray/array.rs | 184 +++++++++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 14 +- 8 files changed, 400 insertions(+), 46 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/array.hpp create mode 100644 nac3core/src/codegen/object/ndarray/array.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index aa25667d..42c83355 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -8,4 +8,5 @@ #include "irrt/ndarray/basic.hpp" #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" -#include "irrt/ndarray/indexing.hpp" \ No newline at end of file +#include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/array.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 00000000..d509f7f7 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/list.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray { +namespace array { +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], + * [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the + * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because + * of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List* list, SizeT ndims, SizeT* shape) { + if (shape[axis] == -1) { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } else { + // Dimension is specified. Check. + if (shape[axis] != list->len) { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. + raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) { + // `list` has type `list[ItemType]` + // Do nothing + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + for (SizeT i = 0; i < list->len; i++) { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template +void set_and_validate_list_shape(List* list, SizeT ndims, SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. + */ +template +void write_list_to_array_helper(SizeT axis, SizeT* index, List* list, NDArray* ndarray) { + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) { + if (!ndarray::basic::is_c_contiguous(ndarray)) { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + + for (SizeT i = 0; i < list->len; i++) { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template +void write_list_to_array(List* list, NDArray* ndarray) { + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace array +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::array; + +void __nac3_ndarray_array_set_and_validate_list_shape(List* list, int32_t ndims, int32_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_set_and_validate_list_shape64(List* list, int64_t ndims, int64_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_write_list_to_array(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} + +void __nac3_ndarray_array_write_list_to_array64(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 94996cd3..684c0d24 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -19,7 +19,10 @@ use super::{ llvm_intrinsics, macros::codegen_unreachable, model::{function::FnCall, *}, - object::ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, + object::{ + list::List, + ndarray::{indexing::NDIndex, nditer::NDIter, NDArray}, + }, stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; @@ -1129,3 +1132,32 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( .arg(dst_ndarray) .returning_void(); } + +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: Instance<'ctx, Ptr>>>>, + ndims: Instance<'ctx, Int>, + shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_set_and_validate_list_shape", + ); + FnCall::builder(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void(); +} + +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: Instance<'ctx, Ptr>>>>, + ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_write_list_to_array", + ); + FnCall::builder(generator, ctx, &name).arg(list).arg(ndarray).returning_void(); +} diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs index b5258ec0..c635a0bc 100644 --- a/nac3core/src/codegen/model/ptr.rs +++ b/nac3core/src/codegen/model/ptr.rs @@ -181,6 +181,15 @@ impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr> { Ptr(new_item).pointer_cast(generator, ctx, self.value) } + /// Cast this pointer to `uint8_t*` + pub fn cast_to_pi8( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr>> { + Ptr(Int(Byte)).pointer_cast(generator, ctx, self.value) + } + /// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`]. pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int> { let value = ctx.builder.build_is_null(self.value, "").unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 68e5149f..17460250 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -19,6 +19,7 @@ use super::{ }, llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, + model::*, object::{ any::AnyObject, ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, @@ -29,13 +30,13 @@ use super::{ use crate::{ symbol_resolver::ValueEnum, toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::extract_ndims, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{FunSignature, Type}, }, }; @@ -1839,26 +1840,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1874,28 +1855,18 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let object = AnyObject { value: obj_arg, ty: obj_ty }; + // NAC3 booleans are i8. + let copy = Int(Bool).truncate(generator, context, copy_arg.into_int_value()); + let ndarray = NDArrayObject::make_np_array(generator, context, object, copy) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.eye`. diff --git a/nac3core/src/codegen/object/list.rs b/nac3core/src/codegen/object/list.rs index 36a864d5..6e27ae71 100644 --- a/nac3core/src/codegen/object/list.rs +++ b/nac3core/src/codegen/object/list.rs @@ -30,6 +30,17 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List { } } +impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr>>> { + /// Cast the items pointer to `uint8_t*`. + pub fn with_pi8_items( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr>>>> { + self.pointer_cast(generator, ctx, Struct(List { item: Int(Byte) })) + } +} + /// A NAC3 Python List object. #[derive(Debug, Clone, Copy)] pub struct ListObject<'ctx> { diff --git a/nac3core/src/codegen/object/ndarray/array.rs b/nac3core/src/codegen/object/ndarray/array.rs new file mode 100644 index 00000000..cca00692 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/array.rs @@ -0,0 +1,184 @@ +use super::NDArrayObject; +use crate::{ + codegen::{ + irrt::{ + call_nac3_ndarray_array_set_and_validate_list_shape, + call_nac3_ndarray_array_write_list_to_array, + }, + model::*, + object::{any::AnyObject, list::ListObject}, + stmt::gen_if_else_expr_callback, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`. +fn get_list_object_dtype_and_ndims<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, +) -> (Type, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type); + + let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type); + let ndims = ndims + 1; // To count `list` itself. + + (dtype, ndims) +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn make_np_array_list_copy_true_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + ) -> Self { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list); + let list_value = list.instance.with_pi8_items(generator, ctx); + + // Validate `list` has a consistent shape. + // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false); + let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, shape, + ); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.create_data(generator, ctx); + + // Copy all contents from the list. + call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn make_np_array_list_copy_none_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + ) -> Self { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. + + let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); + if ndims == 1 { + // `list` is not nested + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1); + + // Set data + let data = list.instance.get(generator, ctx, |f| f.items).cast_to_pi8(generator, ctx); + ndarray.instance.set(ctx, |f| f.data, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.instance.get(generator, ctx, |f| f.shape); + let list_len = list.instance.get(generator, ctx, |f| f.len); + shape.set_index_const(ctx, 0, list_len); + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(generator, ctx); + + ndarray + } else { + // `list` is nested, copy + NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn make_np_array_list_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy.value), + |generator, ctx| { + let ndarray = + NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list); + Ok(Some(ndarray.instance.value)) + }, + |generator, ctx| { + let ndarray = + NDArrayObject::make_np_array_list_copy_none_impl(generator, ctx, list); + Ok(Some(ndarray.instance.value)) + }, + ) + .unwrap() + .unwrap(); + + NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn make_np_array_ndarray_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy.value), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.instance.value)) + }, + |_generator, _ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.instance.value)) + }, + ) + .unwrap() + .unwrap(); + + NDArrayObject::from_value_and_unpacked_types( + generator, + ctx, + ndarray_val, + ndarray.dtype, + ndarray.ndims, + ) + } + + /// Create a new ndarray like `np.array()`. + /// + /// NOTE: The `ndmin` argument is not here. You may want to + /// do [`NDArrayObject::atleast_nd`] to achieve that. + pub fn make_np_array( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, object); + NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, object); + NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy) + } + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index a888ea70..85643a60 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -21,6 +21,7 @@ use crate::{ typecheck::typedef::Type, }; +pub mod array; pub mod factory; pub mod indexing; pub mod nditer; @@ -73,8 +74,19 @@ impl<'ctx> NDArrayObject<'ctx> { ) -> NDArrayObject<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty); let ndims = extract_ndims(&ctx.unifier, ndims); + Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims) + } - let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, object.value).unwrap(); + /// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's + /// `dtype` and `ndims`. + pub fn from_value_and_unpacked_types, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: V, + dtype: Type, + ndims: u64, + ) -> Self { + let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap(); NDArrayObject { dtype, ndims, instance: value } } -- 2.44.2