diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9b5af0f11..9c57919ca 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -3,7 +3,6 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; -use itertools::Itertools; use nac3parser::ast::{Operator, StrRef}; @@ -19,17 +18,28 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, - types::{ndarray::NDArrayType, ListType, ProxyType}, + types::{ + ndarray::{ + factory::{ndarray_one_value, ndarray_zero_value}, + NDArrayType, + }, + ListType, ProxyType, + }, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, - ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::{shape::parse_numpy_int_sequence, NDArrayValue}, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{extract_ndims, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::{ magic_methods::Binop, typedef::{FunSignature, Type, TypeEnum}, @@ -174,132 +184,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. - let ndims = shape_tuple.get_type().count_fields(); - - let shape = (0..ndims) - .map(|dim_i| { - ctx.builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .map(BasicValueEnum::into_int_value) - .map(|v| { - ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap() - }) - .unwrap() - }) - .collect_vec(); - - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - let shape_int = - ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap(); - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } -} - /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as /// its input. fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( @@ -529,107 +413,6 @@ where Ok(res) } -/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - })?; - - Ok(ndarray) -} - /// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( generator: &G, @@ -1752,8 +1535,15 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_empty(generator, context, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1770,8 +1560,15 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_zeros(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.ones`. @@ -1788,8 +1585,15 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_ones(generator, context, dtype, &shape, None); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.full`. @@ -1809,8 +1613,15 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let llvm_dtype = context.get_llvm_type(generator, dtype); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); + + let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims)) + .construct_numpy_full(generator, context, &shape, fill_value_arg, None); + Ok(ndarray.as_base_value()) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/codegen/types/ndarray/factory.rs b/nac3core/src/codegen/types/ndarray/factory.rs new file mode 100644 index 000000000..13aae8cd5 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/factory.rs @@ -0,0 +1,146 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use super::NDArrayType; +use crate::{ + codegen::{ + irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +/// Get the zero value in `np.zeros()` of a `dtype`. +pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayType<'ctx> { + /// Create an ndarray like + /// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html). + pub fn construct_numpy_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_uninitialized(generator, ctx, name); + + // Validate `shape` + irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape); + + ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { ndarray.create_data(generator, ctx) }; + + ndarray + } + + /// Create an ndarray like + /// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html). + pub fn construct_numpy_full( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + fill_value: BasicValueEnum<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let ndarray = self.construct_numpy_empty(generator, ctx, shape, name); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like + /// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html). + pub fn construct_numpy_zeros( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_zero_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } + + /// Create an ndarray like + /// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html). + pub fn construct_numpy_ones( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, + name: Option<&'ctx str>, + ) -> >::Value { + assert_eq!( + ctx.get_llvm_type(generator, dtype), + self.dtype, + "Expected LLVM dtype={} but got {}", + self.dtype.print_to_string(), + ctx.get_llvm_type(generator, dtype).print_to_string(), + ); + + let fill_value = ndarray_one_value(generator, ctx, dtype); + self.construct_numpy_full(generator, ctx, shape, fill_value, name) + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 3886ce842..892416180 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -25,6 +25,7 @@ pub use indexing::*; pub use nditer::*; mod contiguous; +pub mod factory; mod indexing; mod nditer; diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 7ce8ed79f..9b71693a1 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,8 +163,13 @@ impl<'ctx> NDIterType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> >::Value { + assert!( + ndarray.get_type().ndims().is_some(), + "NDIter requires ndims of NDArray to be known." + ); + let nditer = self.raw_alloca_var(generator, ctx, None); - let ndims = ndarray.load_ndims(ctx); + let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false); // The caller has the responsibility to allocate 'indices' for `NDIter`. let indices = diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index e47876c6c..ffde76c9b 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -23,6 +23,7 @@ pub use nditer::*; mod contiguous; mod indexing; mod nditer; +pub mod shape; mod view; /// Proxy type for accessing an `NDArray` value in LLVM. @@ -397,6 +398,23 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); } + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |_, ctx, _, nditer| { + let p = nditer.get_pointer(ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. #[must_use] pub fn is_unsized(&self) -> Option { diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs new file mode 100644 index 000000000..190a1e4fc --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -0,0 +1,152 @@ +use inkwell::values::{BasicValueEnum, IntValue}; + +use crate::{ + codegen::{ + stmt::gen_for_callback_incrementing, + types::{ListType, TupleType}, + values::{ + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to +/// `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), +) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let zero = llvm_usize.const_zero(); + let one = llvm_usize.const_int(1, false); + + // The result `list` to return. + match &*ctx.unifier.get_ty_immutable(input_seq_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_pointer_value(), None); + + let len = input_seq.load_size(ctx, None); + // TODO: Find a way to remove this mid-BB allocation + let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap(); + let result = TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(result, len, None), + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_callback_incrementing( + generator, + ctx, + None, + zero, + (len, false), + |generator, ctx, _, i| { + // Load the i-th int32 in the input sequence + let int = unsafe { + input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value() + }; + + // Cast to SizeT + let int = + ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + // Store + unsafe { result.set_typed_unchecked(ctx, generator, &i, int) }; + + Ok(()) + }, + one, + ) + .unwrap(); + + result + } + + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) + .map_value(input_seq.into_struct_value(), None); + + let len = input_seq.get_type().num_elements(); + + let result = generator + .gen_array_var_alloc( + ctx, + llvm_usize.into(), + llvm_usize.const_int(u64::from(len), false), + None, + ) + .unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + + for i in 0..input_seq.get_type().num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_seq.load_element(ctx, i).into_int_value(); + let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); + + unsafe { + result.set_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(u64::from(i), false), + int, + ); + } + } + + result + } + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + let input_int = input_seq.into_int_value(); + + let len = one; + let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap(); + let result = TypedArrayLikeAdapter::from( + result, + |_, _, val| val.into_int_value(), + |_, _, val| val.into(), + ); + let int = + ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap(); + + // Storing into result[0] + unsafe { + result.set_typed_unchecked(ctx, generator, &zero, int); + } + + result + } + + _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)), + } +}