From 1eb462a5c2ca5f1289664a1a0f4c331873bf7dd2 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 14:51:40 +0800 Subject: [PATCH] [core] codegen/ndarray: Reimplement np_array() Based on 8f0084ac: core/ndstrides: implement np_array() It also checks for inconsistent dimensions if the input is a list. e.g., rejecting `[[1.0, 2.0], [3.0]]`. However, currently only `np_array(, copy=False)` and `np_array (, copy=True)` are supported. In NumPy, copy could be false, true, or None. Right now, NAC3's `np_array(, copy=False)` behaves like NumPy's `np.array(, copy=None)`. --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/list.hpp | 15 + nac3core/irrt/irrt/ndarray/array.hpp | 134 ++++++ nac3core/src/codegen/irrt/ndarray/array.rs | 63 +++ nac3core/src/codegen/irrt/ndarray/mod.rs | 2 + nac3core/src/codegen/numpy.rs | 460 +------------------- nac3core/src/codegen/types/ndarray/array.rs | 244 +++++++++++ nac3core/src/codegen/types/ndarray/mod.rs | 1 + nac3core/src/codegen/values/list.rs | 19 +- nac3core/src/codegen/values/ndarray/mod.rs | 2 +- 10 files changed, 496 insertions(+), 445 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/array.hpp create mode 100644 nac3core/src/codegen/irrt/ndarray/array.rs create mode 100644 nac3core/src/codegen/types/ndarray/array.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 8447fc5a..06178172 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -8,3 +8,4 @@ #include "irrt/ndarray/def.hpp" #include "irrt/ndarray/iter.hpp" #include "irrt/ndarray/indexing.hpp" +#include "irrt/ndarray/array.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index 28543945..1edfe498 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -2,6 +2,21 @@ #include "irrt/int_types.hpp" #include "irrt/math_util.hpp" +#include "irrt/slice.hpp" + +namespace { +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. + */ +template +struct List { + uint8_t* items; + SizeT len; +}; +} // namespace extern "C" { // Handle list assignment and dropping part of the list when diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 00000000..02c55ab7 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/list.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/def.hpp" + +namespace { +namespace ndarray { +namespace array { +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], + * [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the + * responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because + * of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List* list, SizeT ndims, SizeT* shape) { + if (shape[axis] == -1) { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } else { + // Dimension is specified. Check. + if (shape[axis] != list->len) { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. + raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) { + // `list` has type `list[ItemType]` + // Do nothing + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + for (SizeT i = 0; i < list->len; i++) { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template +void set_and_validate_list_shape(List* list, SizeT ndims, SizeT* shape) { + for (SizeT axis = 0; axis < ndims; axis++) { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. + */ +template +void write_list_to_array_helper(SizeT axis, SizeT* index, List* list, NDArray* ndarray) { + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) { + if (!ndarray::basic::is_c_contiguous(ndarray)) { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t* dst = static_cast(ndarray->data) + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } else { + // `list` has type `list[list[...]]` + List** lists = (List**)(list->items); + + for (SizeT i = 0; i < list->len; i++) { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template +void write_list_to_array(List* list, NDArray* ndarray) { + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace array +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::array; + +void __nac3_ndarray_array_set_and_validate_list_shape(List* list, int32_t ndims, int32_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_set_and_validate_list_shape64(List* list, int64_t ndims, int64_t* shape) { + set_and_validate_list_shape(list, ndims, shape); +} + +void __nac3_ndarray_array_write_list_to_array(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} + +void __nac3_ndarray_array_write_list_to_array64(List* list, NDArray* ndarray) { + write_list_to_array(list, ndarray); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs new file mode 100644 index 00000000..46466a7c --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -0,0 +1,63 @@ +use inkwell::{types::BasicTypeEnum, values::IntValue}; + +use crate::codegen::{ + expr::infer_and_call_function, + irrt::get_usize_dependent_function_name, + values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor}, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndims: IntValue<'ctx>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, +) { + let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + assert_eq!(ndims.get_type(), llvm_usize); + assert_eq!( + BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), + llvm_usize.into() + ); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_set_and_validate_list_shape", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + None, + None, + ); +} + +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + list: ListValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) { + assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); + + let name = get_usize_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_write_list_to_array", + ); + + infer_and_call_function( + ctx, + &name, + None, + &[list.as_base_value().into(), ndarray.as_base_value().into()], + None, + None, + ); +} diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 56017c94..307ec6bb 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -16,10 +16,12 @@ use crate::codegen::{ }, CodeGenContext, CodeGenerator, }; +pub use array::*; pub use basic::*; pub use indexing::*; pub use iter::*; +mod array; mod basic; mod indexing; mod iter; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9c57919c..09d848dc 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,7 +1,7 @@ use inkwell::{ - types::{BasicType, BasicTypeEnum, PointerType}, + types::BasicType, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + IntPredicate, OptimizationLevel, }; use nac3parser::ast::{Operator, StrRef}; @@ -18,12 +18,9 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ - ndarray::{ - factory::{ndarray_one_value, ndarray_zero_value}, - NDArrayType, - }, - ListType, ProxyType, + types::ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, }, values::{ ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, @@ -35,14 +32,10 @@ use super::{ }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{ - helper::{extract_ndims, PrimDef}, - numpy::unpack_ndarray_var_tys, - DefinitionId, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{FunSignature, Type}, }, }; @@ -413,394 +406,6 @@ where Ok(res) } -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type().unwrap(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type().unwrap(); - - match list_elem_ty { - BasicTypeEnum::PointerType(ptr_ty) - if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.shape(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - let offset = ctx - .builder - .build_int_mul( - offset, - ctx.builder - .build_int_truncate_or_bit_cast( - dst_arr.get_type().element_type().size_of().unwrap(), - offset.get_type(), - "", - ) - .unwrap(), - "", - ) - .unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_pointer_value( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - BasicTypeEnum::PointerType(ptr_ty) - if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => - { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); - let sizeof_elem = - ctx.builder.build_int_z_extend_or_bit_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. -fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, None, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_pointer_value( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_elem_ty, - None, - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_representable(object, llvm_usize).is_ok()); - let object = ListValue::from_pointer_value(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = - ListValue::from_pointer_value(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(next_dim, llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_pointer_value( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - /// LLVM-typed implementation for generating the implementation for `ndarray.eye`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1635,26 +1240,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1670,28 +1255,17 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let copy = generator.bool_to_i1(context, copy_arg.into_int_value()); + let ndarray = NDArrayType::from_unifier_type(generator, context, fun.0.ret) + .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.eye`. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs new file mode 100644 index 00000000..b19a2595 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -0,0 +1,244 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, IntValue}, + AddressSpace, +}; + +use crate::{ + codegen::{ + irrt, + stmt::gen_if_else_expr_callback, + types::{ndarray::NDArrayType, ListType, ProxyType}, + values::{ + ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAdapter, TypedArrayLikeMutator, + }, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`. +fn get_list_object_dtype_and_ndims<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + list_ty: Type, +) -> (BasicTypeEnum<'ctx>, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list_ty); + let ndims = arraylike_get_ndims(&mut ctx.unifier, list_ty); + + (ctx.get_llvm_type(generator, dtype), ndims) +} + +impl<'ctx> NDArrayType<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn construct_numpy_array_from_list_copy_true_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims_int)); + assert_eq!(dtype, self.dtype); + + let list_value = list.as_i8_list(generator, ctx); + + // Validate `list` has a consistent shape. + // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = self.llvm_usize.const_int(ndims_int, false); + let shape = ctx.builder.build_array_alloca(self.llvm_usize, ndims, "").unwrap(); + let shape = ArraySliceValue::from_ptr_val(shape, ndims, None); + let shape = TypedArrayLikeAdapter::from( + shape, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + irrt::ndarray::call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, &shape, + ); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(ndims_int)) + .construct_uninitialized(generator, ctx, name); + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + // Copy all contents from the list. + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( + generator, ctx, list_value, ndarray, + ); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn construct_numpy_array_from_list_copy_none_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + name: Option<&'ctx str>, + ) -> >::Value { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + if ndims == 1 { + // `list` is not nested + assert_eq!(ndims, 1); + assert!(self.ndims.is_none_or(|self_ndims| self_ndims >= ndims)); + assert_eq!(dtype, self.dtype); + + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + + let ndarray = Self::new(generator, ctx.ctx, dtype, Some(1)) + .construct_uninitialized(generator, ctx, name); + + // Set data + let data = ctx + .builder + .build_pointer_cast(list.data().base_ptr(ctx, generator), llvm_pi8, "") + .unwrap(); + ndarray.store_data(ctx, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.shape(); + let list_len = list.load_size(ctx, None); + unsafe { + shape.set_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), list_len); + } + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(generator, ctx); + + ndarray + } else { + // `list` is nested, copy + self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn construct_numpy_array_list_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (list_ty, list): (Type, ListValue<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_true_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + |generator, ctx| { + let ndarray = self.construct_numpy_array_from_list_copy_none_impl( + generator, + ctx, + (list_ty, list), + name, + ); + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims)).map_value(ndarray, None) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn construct_numpy_array_ndarray_impl( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!(ndarray.get_type().dtype, self.dtype); + assert!(ndarray.get_type().ndims.is_none_or(|ndarray_ndims| self + .ndims + .is_none_or(|self_ndims| self_ndims >= ndarray_ndims))); + assert_eq!(copy.get_type(), ctx.ctx.bool_type()); + + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.as_base_value())) + }, + |_generator, _ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.as_base_value())) + }, + ) + .unwrap() + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + ndarray.get_type().map_value(ndarray_val, name) + } + + /// Create a new ndarray like `np.array()`. + /// + /// Note that the returned [`NDArrayValue`] may have fewer dimensions than is specified by this + /// instance. Use [`NDArrayValue::atleast_nd`] on the returned value if an `ndarray` instance + /// with the exact number of dimensions is needed. + pub fn construct_numpy_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (object_ty, object): (Type, BasicValueEnum<'ctx>), + copy: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + match &*ctx.unifier.get_ty_immutable(object_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) + .map_value(object.into_pointer_value(), None); + self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) + } + + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 1ff6e2d1..dd41df67 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -24,6 +24,7 @@ pub use contiguous::*; pub use indexing::*; pub use nditer::*; +mod array; mod contiguous; pub mod factory; mod indexing; diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index bd115a2d..c497f8f8 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::{structure::StructField, ListType}, + types::{structure::StructField, ListType, ProxyType}, {CodeGenContext, CodeGenerator}, }; @@ -116,6 +116,23 @@ impl<'ctx> ListValue<'ctx> { ) -> IntValue<'ctx> { self.len_field(ctx).get(ctx, self.value, name) } + + /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. + #[must_use] + pub fn as_i8_list( + &self, + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> ListValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + + Self::from_pointer_value( + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + self.llvm_usize, + self.name, + ) + } } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index ffde76c9..d4a460a5 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -173,7 +173,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { let data = ctx .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")