From 937b36dcfd5289d7f769e3e054b3e429370c241e Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 8 Aug 2024 14:58:26 +0800 Subject: [PATCH] core/ndstrides: checkpoint 2 --- nac3core/src/codegen/builtin_fns.rs | 975 +++++++++++------- nac3core/src/codegen/irrt/mod.rs | 4 +- nac3core/src/codegen/irrt/util.rs | 4 +- nac3core/src/codegen/model/structure.rs | 6 +- nac3core/src/codegen/model/util.rs | 35 +- nac3core/src/codegen/numpy_new.rs | 14 +- nac3core/src/codegen/stmt.rs | 3 +- nac3core/src/codegen/structure/list.rs | 34 +- .../codegen/structure/ndarray/broadcast.rs | 94 +- .../codegen/structure/ndarray/functions.rs | 477 +++++++++ .../src/codegen/structure/ndarray/indexing.rs | 21 +- .../src/codegen/structure/ndarray/mapping.rs | 250 ++--- nac3core/src/codegen/structure/ndarray/mod.rs | 78 +- .../src/codegen/structure/ndarray/scalar.rs | 41 +- .../codegen/structure/ndarray/shape_util.rs | 116 +-- nac3core/src/toplevel/builtins.rs | 10 +- 16 files changed, 1459 insertions(+), 703 deletions(-) create mode 100644 nac3core/src/codegen/structure/ndarray/functions.rs diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index e5e74ade..89d5d710 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,22 +1,17 @@ -use inkwell::types::{BasicTypeEnum, IntType}; -use inkwell::values::{BasicValue, BasicValueEnum, FloatValue, IntValue, PointerValue}; +use inkwell::types::BasicTypeEnum; +use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; use crate::codegen::classes::{ NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; -use crate::codegen::llvm_intrinsics::call_float_rint; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; -use crate::codegen::structure::ndarray::mapping::starmap_scalars_array_like; -use 
crate::codegen::structure::ndarray::NDArrayObject; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::{Type, TypeEnum}; - -use super::structure::ndarray::scalar::{split_scalar_or_ndarray, ScalarOrNDArray}; +use crate::typecheck::typedef::Type; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// @@ -28,74 +23,66 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) - ) } -fn handle_cast_to_int_conversion<'ctx, G, F>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), - func_name: &str, - int_type: Type, - handle_float: F, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - F: FnOnce(&mut G, &mut CodeGenContext<'ctx, '_>, FloatValue<'ctx>) -> IntValue<'ctx>, -{ - let llvm_int_type = ctx.get_llvm_type(generator, int_type).into_int_type(); - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, |generator, ctx, _i, scalar| { - let int = match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.bool) => { - // For booleans, simply extend its number of bits - let n = scalar.value.into_int_value(); - ctx.builder.build_int_z_extend(n, llvm_int_type, "").unwrap() - } - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - // For floats, cast it - let n = scalar.value.into_float_value(); - handle_float(generator, ctx, n) - } - () if ctx.unifier.unioned_any( - scalar.dtype, - [ - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ], - ) => - { - // For int32, int64, uint32, uint64, sign-extend or truncate - let n = scalar.value.into_int_value(); - if n.get_type().get_bit_width() <= llvm_int_type.get_bit_width() { - ctx.builder.build_int_s_extend(n, llvm_int_type, "").unwrap() - } else { - 
ctx.builder.build_int_truncate(n, llvm_int_type, "").unwrap() - } - } - () => unsupported_type(ctx, func_name, &[scalar.dtype]), - }; - Ok(ScalarObject { dtype: int_type, value: int.as_basic_value_enum() }) - }) - .map(ScalarOrNDArray::to_basic_value_enum) -} - /// Invokes the `int32` builtin function. pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - handle_cast_to_int_conversion( - generator, - ctx, - n, - "int32", - ctx.primitives.int32, - |_generator, ctx, n| { - let n = ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); - ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap() - }, - ) + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); + + ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { + debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + n.into() + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() + } + + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + let to_int64 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == 
PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.int32, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, "int32", &[n_ty]), + }) } /// Invokes the `int64` builtin function. @@ -104,16 +91,60 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - handle_cast_to_int_conversion( - generator, - ctx, - n, - "int64", - ctx.primitives.int64, - |_generator, ctx, n| { - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap() - }, - ) + let llvm_i64 = ctx.ctx.i64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { + debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { + ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() + } else { + ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() + } + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + n.into() + } + + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + ctx.builder + .build_float_to_signed_int(n, ctx.ctx.i64_type(), "fptosi") + .map(Into::into) + .unwrap() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = 
unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.int64, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, "int64", &[n_ty]), + }) } /// Invokes the `uint32` builtin function. @@ -122,34 +153,76 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - handle_cast_to_int_conversion( - generator, - ctx, - n, - "uint32", - ctx.primitives.uint32, - |_generator, ctx, n| { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); + + ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { + debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + n.into() + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { + debug_assert!( + ctx.unifier.unioned(n_ty, ctx.primitives.int64) + || ctx.unifier.unioned(n_ty, ctx.primitives.uint64) + ); + + ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() + } + + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let n_gez = ctx .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int32 = - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap(); + let to_int32 = ctx.builder.build_float_to_signed_int(n, llvm_i32, "").unwrap(); let to_uint64 = 
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); ctx.builder .build_select( n_gez, - ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(), + ctx.builder.build_int_truncate(to_uint64, llvm_i32, "").unwrap(), to_int32, "conv", ) .unwrap() - .into_int_value() - }, - ) + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.uint32, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, "uint32", &[n_ty]), + }) } /// Invokes the `uint64` builtin function. @@ -158,356 +231,518 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - handle_cast_to_int_conversion( - generator, - ctx, - n, - "uint64", - ctx.primitives.uint64, - |_generator, ctx, n| { + let llvm_i64 = ctx.ctx.i64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { + debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { + ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() + } else { + ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() + } + } + + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { + debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + n.into() + } + + 
BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let val_gez = ctx .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int64 = - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); - let to_uint64 = - ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + let to_int64 = ctx.builder.build_float_to_signed_int(n, llvm_i64, "").unwrap(); + let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, llvm_i64, "").unwrap(); - ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value() - }, - ) + ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.uint64, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, "uint64", &[n_ty]), + }) } /// Invokes the `float` builtin function. 
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), + n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { - Ok(match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - // Handle floats - scalar.value.into_float_value() - } - () if ctx.unifier.unioned_any( - scalar.dtype, - [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64], - ) => - { - // Handle signed ints and booleans (treating booleans as signed ints for convenience) - let n = scalar.value.into_int_value(); - ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() - } - () if ctx - .unifier - .unioned_any(scalar.dtype, [ctx.primitives.uint32, ctx.primitives.uint64]) => - { - // Handle unsigned ints - let n = scalar.value.into_int_value(); - ctx.builder - .build_unsigned_int_to_float(n, llvm_f64, "uitofp") - .map(Into::into) - .unwrap() - } - () => unsupported_type(ctx, "float", &[n_ty]), - } - .as_basic_value_enum()) - }) - .map(ScalarOrNDArray::to_basic_value_enum) -} + let llvm_usize = generator.get_size_type(ctx.ctx); -/// Invokes the `bool` builtin function. 
-pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "bool"; + let (n_ty, n) = n; - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { - Ok(match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.bool) => { - // Handle booleans - scalar.value.into_int_value() - } - () if ctx.unifier.unioned_any( - scalar.dtype, - [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ], - ) => - { - // Handle signed/unsigned ints - let n = scalar.value.into_int_value(); - ctx.builder - .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) - .unwrap() - } - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - // Handle floats - let f64_type = ctx.ctx.f64_type(); - let n = scalar.value.into_float_value(); - ctx.builder - .build_float_compare(FloatPredicate::UNE, n, f64_type.const_zero(), FN_NAME) - .unwrap() - } - () => unsupported_type(ctx, FN_NAME, &[n_ty]), + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty)) + { + ctx.builder + .build_signed_int_to_float(n, llvm_f64, "sitofp") + .map(Into::into) + .unwrap() + } else { + ctx.builder + .build_unsigned_int_to_float(n, llvm_f64, "uitofp") + .map(Into::into) + .unwrap() } - .as_basic_value_enum()) - }) - .map(ScalarOrNDArray::to_basic_value_enum) + } + + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, 
ctx.primitives.float)); + + n.into() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, "float", &[n_ty]), + }) } /// Invokes the `round` builtin function. pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), - ret_int_ty: Type, + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, ) -> Result, String> { const FN_NAME: &str = "round"; - let ret_int_ty_llvm = ctx.get_llvm_abi_type(generator, ret_int_ty).into_int_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ret_int_ty, |_generator, ctx, _i, scalar| { - Ok(match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - let n = scalar.value.into_float_value(); - let n = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, FN_NAME).unwrap() - } - () => unsupported_type(ctx, FN_NAME, &[n_ty]), - } - .as_basic_value_enum()) - }) - .map(ScalarOrNDArray::to_basic_value_enum) + let (n_ty, n) = n; + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); + + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + let val = llvm_intrinsics::call_float_round(ctx, n, None); + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicValueEnum::PointerValue(n) + if 
n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) } /// Invokes the `np_round` builtin function. pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), + n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_round"; - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { - Ok(match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - let n = scalar.value.into_float_value(); - call_float_rint(ctx, n, None).as_basic_value_enum() - } - () => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) - }) - .map(ScalarOrNDArray::to_basic_value_enum) + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + llvm_intrinsics::call_float_rint(ctx, n, None).into() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) } -#[derive(Debug, Clone, Copy)] 
-pub enum CeilOrFloor { - Ceil, - Floor, -} - -/// Invokes the `ceil`/`floor` builtin function. -pub fn call_ceil_or_floor<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the `bool` builtin function. +pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), - kind: CeilOrFloor, - ret_elem_ty: Type, // Can be float or int/bool + n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - // TODO: Bad Rust tuple type inference leads to this - let fn_name = match kind { - CeilOrFloor::Ceil => "ceil", - CeilOrFloor::Floor => "floor", - }; - let function = match kind { - CeilOrFloor::Ceil => llvm_intrinsics::call_float_ceil, - CeilOrFloor::Floor => llvm_intrinsics::call_float_floor, - }; + const FN_NAME: &str = "bool"; - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ctx.primitives.float, |generator, ctx, _i, scalar| { - Ok(match () { - () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { - let n = scalar.value.into_float_value(); - let n = function(ctx, n, None); + let llvm_usize = generator.get_size_type(ctx.ctx); - if ctx.unifier.unioned(ret_elem_ty, ctx.primitives.float) { - // If `ret_elem_ty` is a float, return the floored float directly - n.as_basic_value_enum() - } else { - // Otherwise ret_elem_ty must be IntType, cast the floored float to int - let ret_elem_ty_llvm = - ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); - ctx.builder - .build_float_to_signed_int(n, ret_elem_ty_llvm, fn_name) - .unwrap() - .as_basic_value_enum() - } - } - () => unsupported_type(ctx, fn_name, &[n_ty]), - }) - }) - .map(ScalarOrNDArray::to_basic_value_enum) + let (n_ty, n) = n; + + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); + + n.into() + } + + BasicValueEnum::IntValue(n) => { + debug_assert!([ + ctx.primitives.int32, + 
ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ] + .iter() + .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + + ctx.builder + .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + ctx.builder + .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + let elem = call_bool(generator, ctx, (elem_ty, val))?; + + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) } -#[derive(Debug, Clone, Copy)] -pub enum MinOrMax { - Min, - Max, -} - -/// Invokes the `min`/`max` builtin function. -pub fn call_min_or_max<'ctx>( +/// Invokes the `floor` builtin function. 
+pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (m_ty, m): (Type, BasicValueEnum<'ctx>), - (n_ty, n): (Type, BasicValueEnum<'ctx>), - kind: MinOrMax, + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, +) -> Result, String> { + const FN_NAME: &str = "floor"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); + + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + let val = llvm_intrinsics::call_float_floor(ctx, n, None); + if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) +} + +/// Invokes the `ceil` builtin function. 
+pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, +) -> Result, String> { + const FN_NAME: &str = "ceil"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (n_ty, n) = n; + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); + + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + + let val = llvm_intrinsics::call_float_ceil(ctx, n, None); + if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicValueEnum::PointerValue(n) + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), + )?; + + ndarray.as_base_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) +} + +/// Invokes the `min` builtin function. 
+pub fn call_min<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + m: (Type, BasicValueEnum<'ctx>), + n: (Type, BasicValueEnum<'ctx>), ) -> BasicValueEnum<'ctx> { - // TODO: Bad Rust type inference leads to this - let fn_name = match kind { - MinOrMax::Min => "min", - MinOrMax::Max => "max", - }; - let signedfn = match kind { - MinOrMax::Min => llvm_intrinsics::call_int_smin, - MinOrMax::Max => llvm_intrinsics::call_int_smax, - }; - let unsignedfn = match kind { - MinOrMax::Min => llvm_intrinsics::call_int_umin, - MinOrMax::Max => llvm_intrinsics::call_int_umax, - }; - let floatfn = match kind { - MinOrMax::Min => llvm_intrinsics::call_float_minnum, - MinOrMax::Max => llvm_intrinsics::call_float_maxnum, - }; + const FN_NAME: &str = "min"; + + let (m_ty, m) = m; + let (n_ty, n) = n; let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { m_ty } else { - unsupported_type(ctx, fn_name, &[m_ty, n_ty]) + unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) }; - match () { - () if ctx.unifier.unioned_any( - common_ty, - [ctx.primitives.bool, ctx.primitives.uint32, ctx.primitives.uint64], - ) => - { - // Handle unsigned ints and booleans (treating booleans as unsigned ints) - unsignedfn(ctx, m.into_int_value(), n.into_int_value(), Some(fn_name)).into() + match (m, n) { + (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty))); + + if [ctx.primitives.int32, ctx.primitives.int64] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty, *ty)) + { + llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() + } else { + llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() + } } - () if ctx.unifier.unioned_any(common_ty, [ctx.primitives.int32, ctx.primitives.int64]) => { - // Handle signed ints - signedfn(ctx, m.into_int_value(), n.into_int_value(), Some(fn_name)).into() 
+ + (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { + debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); + + llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() } - () if ctx.unifier.unioned(common_ty, ctx.primitives.float) => { - // Handle floats - floatfn(ctx, m.into_float_value(), n.into_float_value(), Some(fn_name)).into() - } - () => unsupported_type(ctx, fn_name, &[m_ty, n_ty]), + + _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), } } -/// Invokes the `np_minimum`/`np_maximum` builtin function. -pub fn call_numpy_minimum_or_maximum<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the `np_minimum` builtin function. +pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (x1_ty, x1): (Type, BasicValueEnum<'ctx>), - (x2_ty, x2): (Type, BasicValueEnum<'ctx>), - kind: MinOrMax, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let fn_name = match kind { - MinOrMax::Min => "np_minimum", - MinOrMax::Max => "np_maximum", - }; + const FN_NAME: &str = "np_minimum"; - // starmap_scalars_array_like(generator, ctx, inputs, ret_dtype, mapping) + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; - todo!() + let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; - // let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; + Ok(match (x1, x2) { + (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ] + .iter() + .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); - // Ok(match (x1, x2) { - // (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - // debug_assert!([ - // ctx.primitives.bool, - // ctx.primitives.int32, - // ctx.primitives.uint32, - // 
ctx.primitives.int64, - // ctx.primitives.uint64, - // ctx.primitives.float, - // ] - // .iter() - // .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } - // call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - // } + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); - // (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - // debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } - // call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - // } + (x1, x2) + if [&x1_ty, &x2_ty].into_iter().any(|ty| { + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + }) => + { + let is_ndarray1 = + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_ndarray2 = + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - // (x1, x2) - // if [&x1_ty, &x2_ty].into_iter().any(|ty| { - // ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - // }) => - // { - // let is_ndarray1 = - // x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - // let is_ndarray2 = - // x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - // let dtype = if is_ndarray1 && is_ndarray2 { - // let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - // let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - // debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + ndarray_dtype1 + } else if 
is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { + unreachable!() + }; - // ndarray_dtype1 - // } else if is_ndarray1 { - // unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - // } else if is_ndarray2 { - // unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - // } else { - // unreachable!() - // }; + let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - // let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - // let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )? + .as_base_value() + .into() + } - // numpy::ndarray_elementwise_binop_impl( - // generator, - // ctx, - // dtype, - // None, - // (x1, !is_ndarray1), - // (x2, !is_ndarray2), - // |generator, ctx, (lhs, rhs)| { - // call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - // }, - // )? - // .as_base_value() - // .into() - // } - - // _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - // }) + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + }) } /// Invokes the `max` builtin function. 
@@ -671,8 +906,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); - Ok(()) }, llvm_int64.const_int(1, false), diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index b2ffaad4..980c2767 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -953,9 +953,9 @@ pub fn call_ndarray_calc_broadcast_index< ) } -pub fn call_nac3_throw_dummy_error<'ctx, G: CodeGenerator + ?Sized>( +pub fn call_nac3_throw_dummy_error( generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'_, '_>, ) { let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error"); CallFunction::begin(generator, ctx, &name).returning_void(); diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs index 2a75a4c6..124b5ef1 100644 --- a/nac3core/src/codegen/irrt/util.rs +++ b/nac3core/src/codegen/irrt/util.rs @@ -3,9 +3,9 @@ use crate::codegen::{CodeGenContext, CodeGenerator}; // When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". // When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". 
#[must_use] -pub fn get_sizet_dependent_function_name<'ctx, G: CodeGenerator + ?Sized>( +pub fn get_sizet_dependent_function_name( generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, + ctx: &CodeGenContext<'_, '_>, name: &str, ) -> String { let mut name = name.to_owned(); diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs index 7baa0aed..8e70204e 100644 --- a/nac3core/src/codegen/model/structure.rs +++ b/nac3core/src/codegen/model/structure.rs @@ -48,9 +48,7 @@ struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { field_types: Vec>, } -impl<'ctx, 'a, 'b, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> - for TypeFieldTraversal<'ctx, 'a, G> -{ +impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> { type Out = (); fn add>(&mut self, _name: &'static str, model: M) -> Self::Out { @@ -204,6 +202,6 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel> { M: Model<'ctx>, GetField: FnOnce(S::Fields) -> GepField, { - self.gep(ctx, get_field).store(ctx, value) + self.gep(ctx, get_field).store(ctx, value); } } diff --git a/nac3core/src/codegen/model/util.rs b/nac3core/src/codegen/model/util.rs index f59f6950..2b09a259 100644 --- a/nac3core/src/codegen/model/util.rs +++ b/nac3core/src/codegen/model/util.rs @@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ? /// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. /// The [`IntKind`] is automatically inferred. 
-pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>( +pub fn gen_for_model_auto<'ctx, 'a, G, F, I>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, start: Int<'ctx, I>, stop: Int<'ctx, I>, step: Int<'ctx, I>, body: F, -) -> Result +) -> Result<(), String> where G: CodeGenerator + ?Sized, F: FnOnce( @@ -42,7 +42,7 @@ where &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks<'ctx>, Int<'ctx, I>, - ) -> Result, + ) -> Result<(), String>, I: IntKind<'ctx> + Default, { let int_model = IntModel(I::default()); @@ -60,3 +60,32 @@ where step.value, ) } + +/// Like [`gen_if_callback`] with [`Model`] abstractions and without the `else` block. +pub fn gen_if_model<'ctx, 'a, G, ThenFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + cond: Int<'ctx, Bool>, + then: ThenFn, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, +{ + let current_bb = ctx.builder.get_insert_block().unwrap(); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then"); + let end_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.end"); + + // Inserting into `current_bb`. + ctx.builder.build_conditional_branch(cond.value, then_bb, end_bb).unwrap(); + + // Inserting into `then_bb` + ctx.builder.position_at_end(then_bb); + then(generator, ctx)?; + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Reposition to `end_bb` for continuation. 
+ ctx.builder.position_at_end(end_bb); + + Ok(()) +} diff --git a/nac3core/src/codegen/numpy_new.rs b/nac3core/src/codegen/numpy_new.rs index 551ed7b0..48917551 100644 --- a/nac3core/src/codegen/numpy_new.rs +++ b/nac3core/src/codegen/numpy_new.rs @@ -99,8 +99,7 @@ fn create_empty_ndarray<'ctx, G>( where G: CodeGenerator + ?Sized, { - let shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); - let shape = shape.value.get(generator, ctx, |f| f.items, "shape"); + let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); let ndarray = NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray"); @@ -248,8 +247,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>( split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx); // Process `shape` - let broadcast_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); - let broadcast_shape = broadcast_shape.value.get(generator, ctx, |f| f.items, "shape"); + let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); // NOTE: shape.size should equal to `broadcasted_ndims`. let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims); call_nac3_ndarray_util_assert_shape_no_negative( @@ -300,8 +298,7 @@ pub fn gen_ndarray_reshape<'ctx>( // Process the shape input from user and resolve negative indices. // The resulting `new_shape`'s size should be equal to reshaped_ndims. // This is ensured by the typechecker. - let new_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); - let new_shape = new_shape.value.get(generator, ctx, |f| f.items, "new_shape"); + let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); // Resolve unknown dimensions & validate `new_shape`. 
let new_ndims = sizet_model.constant(generator, ctx.ctx, reshaped_ndims); @@ -353,7 +350,7 @@ pub fn gen_ndarray_arange<'ctx>( // Create data and set elements ndarray.create_data(generator, ctx); - ndarray.foreach(generator, ctx, |_generator, ctx, _hooks, i, pelement| { + ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| { let val = ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap(); ctx.builder.build_store(pelement, val).unwrap(); @@ -495,10 +492,9 @@ pub fn gen_ndarray_transpose<'ctx>( // Parse argument #2 axes let in_axes_ty = fun.0.args[1].ty; let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?; - let in_axes = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty); + let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty); let num_axes = ndarray.get_ndims(generator, ctx.ctx); - let axes = in_axes.value.get(generator, ctx, |f| f.items, "axes"); call_nac3_ndarray_transpose( generator, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 8569b021..7453ea8c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,6 +1,5 @@ use super::model::*; use super::structure::cslice::CSlice; -use super::structure::ndarray::broadcast::broadcast_all_ndarrays; use super::{ super::symbol_resolver::ValueEnum, expr::destructure_range, @@ -438,7 +437,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let value = split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx); - let broadcast_result = broadcast_all_ndarrays(generator, ctx, &vec![target, value]); + let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, &[target, value]); let target = broadcast_result.ndarrays[0]; let value = broadcast_result.ndarrays[1]; diff --git a/nac3core/src/codegen/structure/list.rs b/nac3core/src/codegen/structure/list.rs index 65a07fce..a3d232db 100644 --- 
a/nac3core/src/codegen/structure/list.rs +++ b/nac3core/src/codegen/structure/list.rs @@ -33,44 +33,38 @@ impl<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> StructKind<'ctx> for List, Size: IntKind<'ctx>> { +/// A NAC3 Python List object. +pub struct ListObject<'ctx> { /// Typechecker type of the list items pub item_type: Type, - pub value: Ptr<'ctx, StructModel>>, + pub value: Ptr<'ctx, StructModel, SizeT>>>, } -impl<'ctx, Item: Model<'ctx>, Len: IntKind<'ctx>> ListObject<'ctx, Item, Len> { +impl<'ctx> ListObject<'ctx> { /// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`]. - /// - /// - The `Item` model has to be manually provided, and should match the - /// `get_llvm_type()` of `ty` and the `get_type()`. You may want to use - /// [`AnyModel`] if `ty`'s type is not knowable statically. pub fn from_value_and_type, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - val: V, - ty: Type, - item_model: Item, - len_model: Len, + list_val: V, + list_type: Type, ) -> Self { - let plist_model = PtrModel(StructModel(List { item: item_model, len: len_model })); - // Check typechecker type and extract `item_type` - let item_type = match &*ctx.unifier.get_ty(ty) { + let item_type = match &*ctx.unifier.get_ty(list_type) { TypeEnum::TObj { obj_id, params, .. 
} if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { iter_type_vars(params).next().unwrap().ty // Extract `item_type` } - _ => panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(ty)), + _ => { + panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type)) + } }; - // LLVM types of `item_model` and `ty` should match - let llvm_ty = ctx.get_llvm_type(generator, ty); - item_model.check_type(generator, ctx.ctx, llvm_ty).unwrap(); + let item_model = AnyModel(ctx.get_llvm_type(generator, item_type)); + let plist_model = PtrModel(StructModel(List { item: item_model, len: SizeT })); // Create object - let val = plist_model.check_value(generator, ctx.ctx, val).unwrap(); - ListObject { item_type: item_type, value: val } + let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap(); + ListObject { item_type, value } } } diff --git a/nac3core/src/codegen/structure/ndarray/broadcast.rs b/nac3core/src/codegen/structure/ndarray/broadcast.rs index 857b0240..54a41e5b 100644 --- a/nac3core/src/codegen/structure/ndarray/broadcast.rs +++ b/nac3core/src/codegen/structure/ndarray/broadcast.rs @@ -71,62 +71,64 @@ pub struct BroadcastAllResult<'ctx> { pub ndarrays: Vec>, } -// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently. -pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarrays: &Vec>, -) -> BroadcastAllResult<'ctx> { - assert!(!ndarrays.is_empty()); +impl<'ctx> NDArrayObject<'ctx> { + // TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently. 
+ pub fn broadcast_all( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[Self], + ) -> BroadcastAllResult<'ctx> { + assert!(!ndarrays.is_empty()); - let sizet_model = IntModel(SizeT); - let shape_model = StructModel(ShapeEntry); + let sizet_model = IntModel(SizeT); + let shape_model = StructModel(ShapeEntry); - let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims)); + let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims)); - // Prepare input shape entries - let num_shape_entries = - sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap()); - let shape_entries = - shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries"); - for (i, ndarray) in ndarrays.iter().enumerate() { - let i = sizet_model.constant(generator, ctx.ctx, i as u64).value; + // Prepare input shape entries + let num_shape_entries = + sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap()); + let shape_entries = + shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries"); + for (i, ndarray) in ndarrays.iter().enumerate() { + let i = sizet_model.constant(generator, ctx.ctx, i as u64).value; - let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry"); + let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry"); - let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims"); - shape_entry.set(ctx, |f| f.ndims, this_ndims); + let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims"); + shape_entry.set(ctx, |f| f.ndims, this_ndims); - let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape"); - shape_entry.set(ctx, |f| f.shape, this_shape); - } + let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape"); + shape_entry.set(ctx, |f| f.shape, this_shape); + } - // Prepare destination - let broadcast_ndims_llvm 
= sizet_model.constant(generator, ctx.ctx, broadcast_ndims); - let broadcast_shape = - sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape"); + // Prepare destination + let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims); + let broadcast_shape = + sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape"); - // Compute the target broadcast shape `dst_shape` for all ndarrays. - call_nac3_ndarray_broadcast_shapes( - generator, - ctx, - num_shape_entries, - shape_entries, - broadcast_ndims_llvm, - broadcast_shape, - ); + // Compute the target broadcast shape `dst_shape` for all ndarrays. + call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims_llvm, + broadcast_shape, + ); - // Now that we know about the broadcasting shape, broadcast all the inputs. + // Now that we know about the broadcasting shape, broadcast all the inputs. - // Broadcast all the inputs to shape `dst_shape`. - let broadcast_ndarrays: Vec<_> = ndarrays - .into_iter() - .map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape)) - .collect_vec(); + // Broadcast all the inputs to shape `dst_shape`. 
+ let broadcast_ndarrays: Vec<_> = ndarrays + .iter() + .map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape)) + .collect_vec(); - BroadcastAllResult { - ndims: broadcast_ndims, - shape: broadcast_shape, - ndarrays: broadcast_ndarrays, + BroadcastAllResult { + ndims: broadcast_ndims, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } } } diff --git a/nac3core/src/codegen/structure/ndarray/functions.rs b/nac3core/src/codegen/structure/ndarray/functions.rs new file mode 100644 index 00000000..b594d1d5 --- /dev/null +++ b/nac3core/src/codegen/structure/ndarray/functions.rs @@ -0,0 +1,477 @@ +use inkwell::{ + values::{BasicValue, FloatValue, IntValue}, + FloatPredicate, IntPredicate, +}; +use itertools::Itertools; + +use crate::{ + codegen::{ + llvm_intrinsics, + model::{ + util::{gen_for_model_auto, gen_if_model}, + *, + }, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::{scalar::ScalarObject, NDArrayObject}; + +/// Convenience function to crash the program when types of arguments are not supported. +/// Used to be debugged with a stacktrace. +fn unsupported_type(ctx: &CodeGenContext<'_, '_>, tys: I) -> ! 
+where + I: IntoIterator, +{ + unreachable!( + "unsupported types found '{}'", + tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "), + ) +} + +#[derive(Debug, Clone, Copy)] +pub enum FloorOrCeil { + Floor, + Ceil, +} + +#[derive(Debug, Clone, Copy)] +pub enum MinOrMax { + Min, + Max, +} + +fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.int32, ctx.primitives.int64] +} + +fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.uint32, ctx.primitives.uint64] +} + +fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64] +} + +fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ] +} + +fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + scalar: ScalarObject<'ctx>, + target_int_dtype: Type, + handle_float: HandleFloatFn, +) -> ScalarObject<'ctx> +where + G: CodeGenerator + ?Sized, + HandleFloatFn: + FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>, +{ + let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + + let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { + // Special handling for floats + let n = scalar.value.into_float_value(); + handle_float(generator, ctx, n) + } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { + let n = scalar.value.into_int_value(); + + if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() { + ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap() + } else { + ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap() + } + } else { + unsupported_type(ctx, [scalar.dtype]); + }; + + 
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check + ScalarObject { value: result.into(), dtype: target_int_dtype } +} + +impl<'ctx> ScalarObject<'ctx> { + /// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed. + /// Panic otherwise. + pub fn compare( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: ScalarObject<'ctx>, + rhs: ScalarObject<'ctx>, + int_predicate: IntPredicate, + float_predicate: FloatPredicate, + name: &str, + ) -> Int<'ctx, Bool> { + if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) { + unsupported_type(ctx, [lhs.dtype, rhs.dtype]); + } + + let bool_model = IntModel(Bool); + + let common_ty = lhs.dtype; + let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) { + let lhs = lhs.value.into_float_value(); + let rhs = rhs.value.into_float_value(); + ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap() + } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { + let lhs = lhs.value.into_int_value(); + let rhs = rhs.value.into_int_value(); + ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() + } else { + unsupported_type(ctx, [lhs.dtype, rhs.dtype]); + }; + + bool_model.check_value(generator, ctx.ctx, result).unwrap() + } + + /// Invoke NAC3's builtin `int32()`. + #[must_use] + pub fn cast_to_int32( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + cast_to_int_conversion( + generator, + ctx, + *self, + ctx.primitives.int32, + |_generator, ctx, input| { + let n = + ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap() + }, + ) + } + + /// Invoke NAC3's builtin `int64()`. 
+ #[must_use] + pub fn cast_to_int64( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + cast_to_int_conversion( + generator, + ctx, + *self, + ctx.primitives.int64, + |_generator, ctx, input| { + ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap() + }, + ) + } + + /// Invoke NAC3's builtin `uint32()`. + #[must_use] + pub fn cast_to_uint32( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + cast_to_int_conversion( + generator, + ctx, + *self, + ctx.primitives.uint32, + |_generator, ctx, n| { + let n_gez = ctx + .builder + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") + .unwrap(); + + let to_int32 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + + ctx.builder + .build_select( + n_gez, + ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(), + to_int32, + "conv", + ) + .unwrap() + .into_int_value() + }, + ) + } + + /// Invoke NAC3's builtin `uint64()`. + #[must_use] + pub fn cast_to_uint64( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + cast_to_int_conversion( + generator, + ctx, + *self, + ctx.primitives.uint64, + |_generator, ctx, n| { + let val_gez = ctx + .builder + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") + .unwrap(); + + let to_int64 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder + .build_select(val_gez, to_uint64, to_int64, "conv") + .unwrap() + .into_int_value() + }, + ) + } + + /// Invoke NAC3's builtin `bool()`. 
+ #[must_use] + pub fn cast_to_bool( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + // TODO: Why is the original code being so lax about i1 and i8 for the returned int type? + let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { + self.value.into_int_value() + } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { + let n = self.value.into_int_value(); + ctx.builder + .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") + .unwrap() + } else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let n = self.value.into_float_value(); + ctx.builder + .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") + .unwrap() + } else { + unsupported_type(ctx, [self.dtype]) + }; + + ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() } + } + + /// Invoke NAC3's builtin `round()`. + #[must_use] + pub fn round( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_int_dtype: Type, + ) -> Self { + let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + + let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let n = self.value.into_float_value(); + let n = llvm_intrinsics::call_float_round(ctx, n, None); + ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap() + } else { + unsupported_type(ctx, [self.dtype, target_int_dtype]) + }; + ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() } + } + + /// Invoke NAC3's builtin `np_round()`. + /// + /// NOTE: `np.round()` has different behaviors than `round()` in terms of their result + /// on "tie" cases and return type. 
+ #[must_use] + pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { + let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let n = self.value.into_float_value(); + llvm_intrinsics::call_float_rint(ctx, n, None) + } else { + unsupported_type(ctx, [self.dtype]) + }; + ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() } + } + + /// Invoke NAC3's builtin `min()` or `max()`. + fn min_or_max_helper( + ctx: &mut CodeGenContext<'ctx, '_>, + kind: MinOrMax, + a: Self, + b: Self, + ) -> Self { + if !ctx.unifier.unioned(a.dtype, b.dtype) { + unsupported_type(ctx, [a.dtype, b.dtype]) + } + + let common_dtype = a.dtype; + + if ctx.unifier.unioned(common_dtype, ctx.primitives.float) { + let function = match kind { + MinOrMax::Min => llvm_intrinsics::call_float_minnum, + MinOrMax::Max => llvm_intrinsics::call_float_maxnum, + }; + let result = + function(ctx, a.value.into_float_value(), b.value.into_float_value(), None); + ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } + } else if ctx.unifier.unioned_any( + common_dtype, + [unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(), + ) { + // Treating bool has an unsigned int since that is convenient + let function = match kind { + MinOrMax::Min => llvm_intrinsics::call_int_umin, + MinOrMax::Max => llvm_intrinsics::call_int_umax, + }; + let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None); + ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype } + } else { + unsupported_type(ctx, [common_dtype]) + } + } + + /// Invoke NAC3's builtin `floor()` or `ceil()`. 
+ #[must_use] + pub fn floor_or_ceil( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: FloorOrCeil, + target_int_dtype: Type, + ) -> Self { + let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + + if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let function = match kind { + FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, + FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, + }; + let n = self.value.into_float_value(); + let n = function(ctx, n, None); + let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap(); + ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() } + } else { + unsupported_type(ctx, [self.dtype]) + } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`. + /// + /// Generate LLVM IR to find the extremum and index of the **first** extremum value. + /// + /// Care has also been taken to make the error messages match that of NumPy. + fn min_or_max_helper( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: MinOrMax, + on_empty_err_msg: &str, + ) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) { + let sizet_model = IntModel(SizeT); + let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); + + // If the ndarray is empty, throw an error. 
+ let is_empty = self.is_empty(generator, ctx); + ctx.make_assert( + generator, + is_empty.value, + "0:ValueError", + on_empty_err_msg, + [None, None, None], + ctx.current_loc, + ); + + // Setup and initialize the extremum to be the first element in the ndarray + let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index"); + let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap(); + + let zero = sizet_model.const_0(generator, ctx.ctx); + pextremum_index.store(ctx, zero); + + let first_scalar = self.get_nth(generator, ctx, zero); + ctx.builder.build_store(pextremum, first_scalar.value).unwrap(); + + // Find extremum + let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1 + let stop = self.size(generator, ctx); + let step = sizet_model.const_1(generator, ctx.ctx); + gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, _hooks, i| { + // Worth reading on "Notes" in + // on how `NaN` values have to be handled. + + let scalar = self.get_nth(generator, ctx, i); + + let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); + let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; + + let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar); + + // Check if new_extremum is more extreme than old_extremum. + let update_index = ScalarObject::compare( + generator, + ctx, + new_extremum, + old_extremum, + IntPredicate::NE, + FloatPredicate::ONE, + "", + ); + + gen_if_model(generator, ctx, update_index, |_generator, ctx| { + pextremum_index.store(ctx, i); + Ok(()) + }) + .unwrap(); + Ok(()) + }) + .unwrap(); + + // Finally return the extremum and extremum index. 
+ let extremum_index = pextremum_index.load(generator, ctx, "extremum_index"); + + let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap(); + let extremum = ScalarObject { dtype: self.dtype, value: extremum }; + + (extremum, extremum_index) + } + + /// Invoke NAC3's builtin `np_min()` or `np_max()`. + pub fn min_or_max( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: MinOrMax, + ) -> ScalarObject<'ctx> { + let on_empty_err_msg = format!( + "zero-size array to reduction operation {} which has no identity", + match kind { + MinOrMax::Min => "minimum", + MinOrMax::Max => "maximum", + } + ); + self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0 + } + + /// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`. + pub fn argmin_or_argmax( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: MinOrMax, + ) -> Int<'ctx, SizeT> { + let on_empty_err_msg = format!( + "attempt to get {} of an empty sequence", + match kind { + MinOrMax::Min => "argmin", + MinOrMax::Max => "argmax", + } + ); + self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1 + } +} diff --git a/nac3core/src/codegen/structure/ndarray/indexing.rs b/nac3core/src/codegen/structure/ndarray/indexing.rs index d2449cab..02cf870b 100644 --- a/nac3core/src/codegen/structure/ndarray/indexing.rs +++ b/nac3core/src/codegen/structure/ndarray/indexing.rs @@ -1,7 +1,4 @@ -use crate::codegen::{ - irrt::call_nac3_ndarray_index, model::*, structure::ndarray::scalar::ScalarObject, - CodeGenContext, CodeGenerator, -}; +use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator}; use super::{scalar::ScalarOrNDArray, NDArrayObject}; @@ -130,7 +127,7 @@ impl<'ctx> RustNDIndex<'ctx> { dst_ndindex_ptr: Ptr<'ctx, StructModel>, ) { let ndindex_type_model = IntModel(NDIndexType::default()); - let i32_model = IntModel(Int32::default()); + let i32_model = IntModel(Int32); let user_slice_model = 
StructModel(UserSlice); // Set `dst_ndindex_ptr->type` @@ -178,6 +175,7 @@ impl<'ctx> RustNDIndex<'ctx> { impl<'ctx> NDArrayObject<'ctx> { /// Get the ndims [`Type`] after indexing with a given slice. + #[must_use] pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 { let mut ndims = self.ndims; for index in indexes { @@ -235,11 +233,8 @@ impl<'ctx> NDArrayObject<'ctx> { let subndarray = self.index(generator, ctx, indexes, name); if subndarray.is_unsized() { - // NOTE: `np.size(self) == 0` is impossible. - let pfirst = subndarray.get_nth_pelement(generator, ctx, zero, "pfirst"); - let first = ctx.builder.build_load(pfirst, "first").unwrap(); - - ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value: first }) + // NOTE: `np.size(self) == 0` here is never possible. + ScalarOrNDArray::Scalar(subndarray.get_nth(generator, ctx, zero)) } else { ScalarOrNDArray::NDArray(subndarray) } @@ -319,9 +314,9 @@ pub mod util { }) }; - let start = help(&start)?; - let stop = help(&stop)?; - let step = help(&step)?; + let start = help(start)?; + let stop = help(stop)?; + let step = help(step)?; RustNDIndex::Slice(RustUserSlice { start, stop, step }) } else { diff --git a/nac3core/src/codegen/structure/ndarray/mapping.rs b/nac3core/src/codegen/structure/ndarray/mapping.rs index 84d835cc..fddd953a 100644 --- a/nac3core/src/codegen/structure/ndarray/mapping.rs +++ b/nac3core/src/codegen/structure/ndarray/mapping.rs @@ -1,12 +1,11 @@ +use inkwell::values::BasicValueEnum; use itertools::Itertools; use util::gen_for_model_auto; use crate::{ codegen::{ model::*, - structure::ndarray::{ - broadcast::broadcast_all_ndarrays, scalar::ScalarObject, NDArrayObject, - }, + structure::ndarray::{scalar::ScalarObject, NDArrayObject}, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -14,179 +13,126 @@ use crate::{ use super::scalar::ScalarOrNDArray; -pub fn starmap_scalars_array_like<'ctx, 'a, F, G>( - generator: &mut G, - ctx: &mut 
CodeGenContext<'ctx, 'a>, - inputs: &Vec>, - mapping: F, -) -> Result, String> -where - F: FnOnce( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - Int<'ctx, SizeT>, - &Vec>, - ) -> Result, String>, - G: CodeGenerator + ?Sized, -{ - assert!(!inputs.is_empty()); - - let sizet_model = IntModel(SizeT); - - // Check if all inputs are ScalarObjects - let scalars: Option> = - inputs.iter().map(|input| ScalarObject::try_from(input)).try_collect().ok(); - - match scalars { - Some(scalars) => { - // When inputs are all scalars, return a ScalarObject back - - let i = sizet_model.const_0(generator, ctx.ctx); - - let scalar = mapping(generator, ctx, i, &scalars)?; - Ok(ScalarOrNDArray::Scalar(scalar)) - } - None => { - // When not all inputs are scalars, promote all non-ndarray inputs - // to ndarrays, do broadcast_shapes on them, and map. - - let ndarrays = - inputs.into_iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec(); - - let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays); - - let start = sizet_model.const_0(generator, ctx.ctx); - let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`. - let step = sizet_model.const_1(generator, ctx.ctx); - - // Map element-wise and store results into `mapped_ndarray`. - let mapped_ndarray = gen_for_model_auto( - generator, - ctx, - start, - stop, - step, - move |generator, ctx, _hooks, i| { - let elements = ndarrays - .iter() - .map(|ndarray| { - let pelement = ndarray.get_nth_pelement(generator, ctx, i, "pelement"); - let element = ctx.builder.build_load(pelement, "element").unwrap(); - ScalarObject { value: element, dtype: ndarray.dtype } - }) - .collect_vec(); - - let ret = mapping(generator, ctx, i, &elements)?; - - // It might look weird but it is perfectly fine putting the allocation codegen - // here within `for`. - // The reason for doing this is to get the `dtype` out of `ret`, which is only - // available after running `mapping`. 
- let mapped_ndarray = NDArrayObject::alloca_uninitialized( - generator, - ctx, - ret.dtype, - broadcast_result.ndims, - "mapped_ndarray", - ); - mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); - mapped_ndarray.create_data(generator, ctx); - - let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret"); - ctx.builder.build_store(pret, ret.value).unwrap(); - Ok(mapped_ndarray) - }, - )?; - Ok(ScalarOrNDArray::NDArray(mapped_ndarray)) - } - } -} - -impl<'ctx> ScalarObject<'ctx> { - pub fn map<'a, F, G>( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - mapping: F, - ) -> Result - where - F: FnOnce( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - Int<'ctx, SizeT>, - ScalarObject<'ctx>, - ) -> Result, String>, - G: CodeGenerator + ?Sized, - { - let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like( - generator, - ctx, - &vec![ScalarOrNDArray::Scalar(*self)], - |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), - )? - else { - unreachable!() - }; - Ok(ret) - } -} - impl<'ctx> NDArrayObject<'ctx> { - pub fn map<'a, F, G>( - &self, + /// TODO: Document me. Has complex behavior. 
+ pub fn broadcasting_starmap<'a, G, MappingFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, - mapping: F, + ndarrays: &[Self], + ret_dtype: Type, + name: &str, + mapping: MappingFn, ) -> Result where - F: FnOnce( + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, - ScalarObject<'ctx>, + &[ScalarObject<'ctx>], ) -> Result, String>, - G: CodeGenerator + ?Sized, { - let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like( + let sizet_model = IntModel(SizeT); + + // Broadcast inputs + let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, ndarrays); + + // Allocate the resulting ndarray + let mapped_ndarray = NDArrayObject::alloca_uninitialized( generator, ctx, - &vec![ScalarOrNDArray::NDArray(*self)], - |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), - )? - else { - unreachable!() - }; - Ok(ret) - } -} + ret_dtype, + broadcast_result.ndims, + name, + ); + mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); + mapped_ndarray.create_data(generator, ctx); -impl<'ctx> ScalarOrNDArray<'ctx> { - pub fn map<'a, F, G>( + // Map element-wise and store results into `mapped_ndarray`. + let start = sizet_model.const_0(generator, ctx.ctx); + let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`. 
+ let step = sizet_model.const_1(generator, ctx.ctx); + gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| { + let elements = + ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec(); + + let ret = mapping(generator, ctx, i, &elements)?; + + let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret"); + ctx.builder.build_store(pret, ret.value).unwrap(); + Ok(()) + })?; + + Ok(mapped_ndarray) + } + + pub fn map<'a, G, Mapping>( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, ret_dtype: Type, - mapping: F, + name: &str, + mapping: Mapping, ) -> Result where - F: FnOnce( + G: CodeGenerator + ?Sized, + Mapping: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, ScalarObject<'ctx>, - ) -> Result, String>, - G: CodeGenerator + ?Sized, + ) -> Result, String>, { - match self { - ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like( - generator, - ctx, - &vec![ScalarOrNDArray::Scalar(*scalar)], - |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), - ), - ScalarOrNDArray::NDArray(ndarray) => { - ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray) - } + NDArrayObject::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + name, + |generator, ctx, i, scalars| { + let value = mapping(generator, ctx, i, scalars[0])?; + Ok(ScalarObject { dtype: ret_dtype, value }) + }, + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// TODO: Document me. Has complex behavior. 
+ pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[Self], + ret_dtype: Type, + name: &str, + mapping: MappingFn, + ) -> Result + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Int<'ctx, SizeT>, + &[ScalarObject<'ctx>], + ) -> Result, String>, + { + let sizet_model = IntModel(SizeT); + + // Check if all inputs are ScalarObjects + let all_scalars: Option> = + inputs.iter().map(ScalarObject::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index + let scalar = mapping(generator, ctx, i, &scalars)?; + Ok(ScalarOrNDArray::Scalar(scalar)) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayObject::broadcasting_starmap( + generator, ctx, &inputs, ret_dtype, name, mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) } } } diff --git a/nac3core/src/codegen/structure/ndarray/mod.rs b/nac3core/src/codegen/structure/ndarray/mod.rs index bfa6eda2..eaf3de7d 100644 --- a/nac3core/src/codegen/structure/ndarray/mod.rs +++ b/nac3core/src/codegen/structure/ndarray/mod.rs @@ -1,4 +1,5 @@ pub mod broadcast; +pub mod functions; pub mod indexing; pub mod mapping; pub mod scalar; @@ -22,8 +23,9 @@ use inkwell::{ context::Context, types::BasicType, values::{BasicValue, BasicValueEnum, PointerValue}, - AddressSpace, + AddressSpace, IntPredicate, }; +use scalar::ScalarObject; use util::{call_memcpy_model, gen_for_model_auto}; pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> { @@ -52,6 +54,7 @@ impl<'ctx> StructKind<'ctx> for NpArray { } } +/// A NAC3 Python ndarray object. 
#[derive(Debug, Clone, Copy)] pub struct NDArrayObject<'ctx> { pub dtype: Type, @@ -116,7 +119,7 @@ impl<'ctx> NDArrayObject<'ctx> { /// Get the pointer to the n-th (0-based) element. /// /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. - pub fn get_nth_pelement( + pub fn get_nth_pointer( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -131,6 +134,18 @@ impl<'ctx> NDArrayObject<'ctx> { .unwrap() } + /// Get the n-th (0-based) scalar. + pub fn get_nth( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Int<'ctx, SizeT>, + ) -> ScalarObject<'ctx> { + let p = self.get_nth_pointer(generator, ctx, nth, "value"); + let value = ctx.builder.build_load(p, "value").unwrap(); + ScalarObject { dtype: self.dtype, value } + } + /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// /// Please refer to the IRRT implementation to see its purpose. @@ -210,7 +225,7 @@ impl<'ctx> NDArrayObject<'ctx> { name: &str, ) -> Self { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); - let ndims = extract_ndims(&mut ctx.unifier, ndims); + let ndims = extract_ndims(&ctx.unifier, ndims); Self::alloca_uninitialized(generator, ctx, dtype, ndims, name) } @@ -224,7 +239,22 @@ impl<'ctx> NDArrayObject<'ctx> { sizet_model.constant(generator, ctx, self.ndims) } - /// Return true if this ndarray is unsized. + /// Get if this ndarray's `np.size` is `0` - containing no content. + pub fn is_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Int<'ctx, Bool> { + let sizet_model = IntModel(SizeT); + + let size = self.size(generator, ctx); + size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty") + } + + /// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar. + /// + /// This is a staticially known property of ndarrays. 
This is why it is returning + /// a Rust boolean instead of a [`BasicValue`]. #[must_use] pub fn is_unsized(&self) -> bool { self.ndims == 0 @@ -273,7 +303,7 @@ impl<'ctx> NDArrayObject<'ctx> { ) { assert_eq!(self.ndims, src_ndarray.ndims); let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape"); - self.copy_shape_from_array(generator, ctx, src_shape) + self.copy_shape_from_array(generator, ctx, src_shape); } /// Copy strides dimensions from an array. @@ -298,14 +328,14 @@ impl<'ctx> NDArrayObject<'ctx> { ) { assert_eq!(self.ndims, src_ndarray.ndims); let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides"); - self.copy_strides_from_array(generator, ctx, src_strides) + self.copy_strides_from_array(generator, ctx, src_strides); } - /// Loop through every element pointer in the ndarray in its flatten view. + /// Iterate through every element pointer in the ndarray in its flatten view. /// /// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's /// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`. - pub fn foreach<'a, G, F>( + pub fn foreach_pointer<'a, G, F>( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -328,12 +358,36 @@ impl<'ctx> NDArrayObject<'ctx> { let step = sizet_model.const_1(generator, ctx.ctx); gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| { - let pelement = self.get_nth_pelement(generator, ctx, i, "element"); + let pelement = self.get_nth_pointer(generator, ctx, i, "element"); body(generator, ctx, hooks, i, pelement) }) } - /// Fill the NDArray with a value. + /// Iterate through every scalar in this ndarray. 
+ pub fn foreach<'a, G, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + body: F, + ) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + Int<'ctx, SizeT>, + ScalarObject<'ctx>, + ) -> Result<(), String>, + { + self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| { + let value = ctx.builder.build_load(p, "value").unwrap(); + let scalar = ScalarObject { dtype: self.dtype, value }; + body(generator, ctx, hooks, i, scalar) + }) + } + + /// Fill the ndarray with a value. /// /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. pub fn fill( @@ -342,11 +396,11 @@ impl<'ctx> NDArrayObject<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, fill_value: BasicValueEnum<'ctx>, ) { - self.foreach(generator, ctx, |_generator, ctx, _hooks, _i, pelement| { + self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| { ctx.builder.build_store(pelement, fill_value).unwrap(); Ok(()) }) - .unwrap() + .unwrap(); } /// Create a reshaped view on this ndarray like `np.reshape()`. diff --git a/nac3core/src/codegen/structure/ndarray/scalar.rs b/nac3core/src/codegen/structure/ndarray/scalar.rs index c1ab12a1..1a03a223 100644 --- a/nac3core/src/codegen/structure/ndarray/scalar.rs +++ b/nac3core/src/codegen/structure/ndarray/scalar.rs @@ -8,6 +8,10 @@ use crate::{ use super::NDArrayObject; /// An LLVM numpy scalar with its [`Type`]. +/// +/// Intended to be used with [`ScalarOrNDArray`]. +/// +/// A scalar does not have to be an actual number. It could be arbitrary objects. 
#[derive(Debug, Clone, Copy)] pub struct ScalarObject<'ctx> { pub dtype: Type, @@ -55,6 +59,22 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } + #[must_use] + pub fn into_scalar(&self) -> ScalarObject<'ctx> { + match self { + ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"), + ScalarOrNDArray::Scalar(scalar) => *scalar, + } + } + + #[must_use] + pub fn into_ndarray(&self) -> NDArrayObject<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"), + } + } + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. /// - If this is an ndarray, the ndarray is returned. /// - If this is a scalar, an unsized ndarray view is created on it. @@ -68,6 +88,14 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx), } } + + #[must_use] + pub fn dtype(&self) -> Type { + match self { + ScalarOrNDArray::Scalar(scalar) => scalar.dtype, + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + } + } } impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> { @@ -76,7 +104,18 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> { fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { match value { ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), - ScalarOrNDArray::NDArray(_) => Err(()), + ScalarOrNDArray::NDArray(_ndarray) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_scalar) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), } } } diff --git a/nac3core/src/codegen/structure/ndarray/shape_util.rs b/nac3core/src/codegen/structure/ndarray/shape_util.rs index 15970e78..591f7dd6 100644 --- a/nac3core/src/codegen/structure/ndarray/shape_util.rs +++ b/nac3core/src/codegen/structure/ndarray/shape_util.rs @@ -2,15 +2,11 @@ use 
inkwell::values::BasicValueEnum; use util::gen_for_model_auto; use crate::{ - codegen::{ - model::*, - structure::list::{List, ListObject}, - CodeGenContext, CodeGenerator, - }, + codegen::{model::*, structure::list::ListObject, CodeGenContext, CodeGenerator}, typecheck::typedef::{Type, TypeEnum}, }; -/// Parse a NumPy-like "int sequence" input and return the int sequence as a [`ListObject`] +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. /// /// * `sequence` - The `sequence` parameter. /// * `sequence_ty` - The typechecker type of `sequence` @@ -20,99 +16,97 @@ use crate::{ /// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` /// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` /// -/// `int32` values will be sign-extended to `SizeT` +/// All `int32` values will be sign-extended to `SizeT`. pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - sequence: BasicValueEnum<'ctx>, - sequence_ty: Type, -) -> ListObject<'ctx, IntModel, SizeT> { + input_sequence: BasicValueEnum<'ctx>, + input_sequence_ty: Type, +) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel>) { let sizet_model = IntModel(SizeT); - let list_model = StructModel(List { len: SizeT, item: IntModel(SizeT) }); let zero = sizet_model.const_0(generator, ctx.ctx); let one = sizet_model.const_1(generator, ctx.ctx); // The result `list` to return. - let result = list_model.alloca(generator, ctx, "result_sequence"); - match &*ctx.unifier.get_ty(sequence_ty) { + match &*ctx.unifier.get_ty(input_sequence_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { // 1. 
A list of `int32`; e.g., `np.empty([600, 800, 3])` - let in_sequence_model = - PtrModel(StructModel(List { item: IntModel(Int32), len: SizeT })); - let in_sequence = in_sequence_model.check_value(generator, ctx.ctx, sequence).unwrap(); - /* - Reference code: - ``` - result.size = sequence.size; - result.data = __builtin_alloca(sizeof(SizeT) * sequence.size); - for (SizeT i = 0; i < sequence.size; i++) { - result.data[i] = (SizeT) sequence.data[i]; - } - return result - ``` - */ + // Check `input_sequence` + let input_sequence = + ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty); - let ndims = in_sequence.get(generator, ctx, |f| f.len, "size"); - result.set(ctx, |f| f.len, ndims); + let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len"); + let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence"); - let result_data = sizet_model.array_alloca(generator, ctx, ndims.value, "data"); - result.set(ctx, |f| f.items, result_data); + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| { + // Load the i-th int32 in the input sequence + let int = input_sequence + .value + .get(generator, ctx, |f| f.items, "int") + .ix(generator, ctx, i.value, "int") + .value + .into_int_value(); + + // Cast to SizeT + let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int"); + + // Store + result.offset(generator, ctx, i.value, "int").store(ctx, int); - gen_for_model_auto(generator, ctx, zero, ndims, one, |generator, ctx, _hooks, i| { - let in_dim = in_sequence - .get(generator, ctx, |f| f.items, "in_dim") - .ix(generator, ctx, i.value, "in_dim") - .s_extend_or_bit_cast(generator, ctx, SizeT, "in_dim"); - result_data.offset(generator, ctx, i.value, "dim").store(ctx, in_dim); Ok(()) }) .unwrap(); + + (len, result) } TypeEnum::TTuple { ty: tuple_types } => { // 2. 
A tuple of ints; e.g., `np.empty((600, 800, 3))` - let ndims_int = tuple_types.len(); - let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int as u64); - result.set(ctx, |f| f.len, ndims); + let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct - // A tuple has to be a StructValue - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - let tuple = sequence.into_struct_value(); - let data = sizet_model.array_alloca(generator, ctx, ndims.value, "sequence_data"); - result.set(ctx, |f| f.items, data); + let len_int = tuple_types.len(); - for i in 0..ndims_int { - // Get the i-th (0-based) element off of the tuple and load it - // into `result`. - let dim = ctx + let len = sizet_model.constant(generator, ctx.ctx, len_int as u64); + let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence"); + + for i in 0..len_int { + // Get the i-th element off of the tuple and load it into `result`. + let int = ctx .builder - .build_extract_value(tuple, i as u32, format!("dim").as_str()) + .build_extract_value(input_sequence, i as u32, "int") .unwrap() .into_int_value(); - let dim = sizet_model.s_extend_or_bit_cast(generator, ctx, dim, "dim"); + let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int"); let offset = sizet_model.constant(generator, ctx.ctx, i as u64); - data.offset(generator, ctx, offset.value, "dim").store(ctx, dim); + result.offset(generator, ctx, offset.value, "int").store(ctx, int); } + + (len, result) } TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => { // 3. 
A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - let sequence_int = sizet_model.check_value(generator, ctx.ctx, sequence).unwrap(); + let input_int = input_sequence.into_int_value(); - // Size is 1 - result.set(ctx, |f| f.len, one); + let len = sizet_model.const_1(generator, ctx.ctx); + let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence"); - // Alloca an array of length 1 and store the sole integer input into the array. - let data = sizet_model.array_alloca(generator, ctx, one.value, "data"); - data.offset(generator, ctx, zero.value, "dim").store(ctx, sequence_int); + let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int"); + + // Storing into result[0] + result.store(ctx, int); + + (len, result) } - _ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(sequence_ty)), + _ => panic!( + "encountered unknown sequence type: {}", + ctx.unifier.stringify(input_sequence_ty) + ), } - - ListObject { item_type: ctx.primitives.usize(), value: result } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1ed3628c..cebbebe9 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> { let ret_elem_ty = size_variant.of_int(&ctx.primitives); let func = match kind { Kind::Ceil => builtin_fns::call_ceil, - Kind::Floor => builtin_fns::call_ceil_or_floor, + Kind::Floor => builtin_fns::call_floor, }; Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), @@ -1361,7 +1361,7 @@ impl<'a> BuiltinBuilder<'a> { let ndims1 = create_ndims(self.unifier, 1); let ndarray_float_1d = make_ndarray_ty( self.unifier, - &self.primitives, + self.primitives, Some(self.primitives.float), Some(ndims1), ); @@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpSize => { // TODO: Make the return type usize create_fn_by_codegen( - &mut self.unifier, + self.unifier, 
&VarMap::new(), prim.name(), self.primitives.int32, @@ -1485,7 +1485,7 @@ impl<'a> BuiltinBuilder<'a> { // of the input ndarray. let ret_ty = self.unifier.get_dummy_var().ty; create_fn_by_codegen( - &mut self.unifier, + self.unifier, &VarMap::new(), prim.name(), ret_ty, @@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> { let func = match prim { PrimDef::FunNpCeil => builtin_fns::call_ceil, - PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor, + PrimDef::FunNpFloor => builtin_fns::call_floor, _ => unreachable!(), }; Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))