From 858b4b9f3f49adc1b9d007de9c8653372a37a7ea Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 9 Aug 2024 12:03:10 +0800 Subject: [PATCH] core/ndstrides: checkpoint 3 --- nac3core/src/codegen/builtin_fns.rs | 1298 +---------------- nac3core/src/codegen/numpy_new.rs | 8 +- .../codegen/structure/ndarray/functions.rs | 139 +- .../src/codegen/structure/ndarray/mapping.rs | 52 +- nac3core/src/codegen/structure/ndarray/mod.rs | 20 + nac3core/src/toplevel/builtins.rs | 349 +++-- nac3core/src/toplevel/numpy.rs | 33 + nac3core/src/typecheck/mod.rs | 1 - nac3core/src/typecheck/numpy.rs | 33 - nac3core/src/typecheck/type_inferencer/mod.rs | 3 +- 10 files changed, 453 insertions(+), 1483 deletions(-) delete mode 100644 nac3core/src/typecheck/numpy.rs diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 89d5d710..9e6f9e75 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,14 +1,12 @@ use inkwell::types::BasicTypeEnum; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; -use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; use crate::codegen::classes::{ NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; -use crate::codegen::stmt::gen_for_callback_incrementing; -use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; +use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; @@ -23,988 +21,6 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) - ) } -/// Invokes the `int32` builtin function. -pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - - ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { - debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let to_int64 = - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); - ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "int32", &[n_ty]), - }) -} - -/// Invokes the `int64` builtin function. -pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() - } else { - ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() - } - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - ctx.builder - .build_float_to_signed_int(n, ctx.ctx.i64_type(), "fptosi") - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "int64", &[n_ty]), - }) -} - -/// Invokes the `uint32` builtin function. -pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - - ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { - debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!( - ctx.unifier.unioned(n_ty, ctx.primitives.int64) - || ctx.unifier.unioned(n_ty, ctx.primitives.uint64) - ); - - ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let n_gez = ctx - .builder - .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") - .unwrap(); - - let to_int32 = ctx.builder.build_float_to_signed_int(n, llvm_i32, "").unwrap(); - let to_uint64 = - ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); - - ctx.builder - .build_select( - n_gez, - ctx.builder.build_int_truncate(to_uint64, llvm_i32, "").unwrap(), - to_int32, - "conv", - ) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "uint32", &[n_ty]), - }) -} - -/// Invokes the `uint64` builtin function. -pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() - } else { - ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() - } - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val_gez = ctx - .builder - .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") - .unwrap(); - - let to_int64 = ctx.builder.build_float_to_signed_int(n, llvm_i64, "").unwrap(); - let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, llvm_i64, "").unwrap(); - - ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "uint64", &[n_ty]), - }) -} - -/// Invokes the `float` builtin function. -pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty)) - { - ctx.builder - .build_signed_int_to_float(n, llvm_f64, "sitofp") - .map(Into::into) - .unwrap() - } else { - ctx.builder - .build_unsigned_int_to_float(n, llvm_f64, "uitofp") - .map(Into::into) - .unwrap() - } - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - n.into() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "float", &[n_ty]), - }) -} - -/// Invokes the `round` builtin function. -pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, -) -> Result, String> { - const FN_NAME: &str = "round"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `np_round` builtin function. -pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_round"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_rint(ctx, n, None).into() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `bool` builtin function. -pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "bool"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - - n.into() - } - - BasicValueEnum::IntValue(n) => { - debug_assert!([ - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - ctx.builder - .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - ctx.builder - .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; - - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `floor` builtin function. -pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, -) -> Result, String> { - const FN_NAME: &str = "floor"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_floor(ctx, n, None); - if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } else { - val.into() - } - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `ceil` builtin function. -pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, -) -> Result, String> { - const FN_NAME: &str = "ceil"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_ceil(ctx, n, None); - if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } else { - val.into() - } - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `min` builtin function. -pub fn call_min<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - m: (Type, BasicValueEnum<'ctx>), - n: (Type, BasicValueEnum<'ctx>), -) -> BasicValueEnum<'ctx> { - const FN_NAME: &str = "min"; - - let (m_ty, m) = m; - let (n_ty, n) = n; - - let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { - m_ty - } else { - unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) - }; - - match (m, n) { - (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty))); - - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty)) - { - llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() - } else { - llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() - } - } - - (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() - } - - _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), - } -} - -/// Invokes the `np_minimum` builtin function. -pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_minimum"; - - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; - - Ok(match (x1, x2) { - (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); - - call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } - - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); - - call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } - - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) -} - -/// Invokes the `max` builtin function. -pub fn call_max<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - m: (Type, BasicValueEnum<'ctx>), - n: (Type, BasicValueEnum<'ctx>), -) -> BasicValueEnum<'ctx> { - const FN_NAME: &str = "max"; - - let (m_ty, m) = m; - let (n_ty, n) = n; - - let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { - m_ty - } else { - unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) - }; - - match (m, n) { - (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty))); - - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty)) - { - llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into() - } else { - llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into() - } - } - - (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() - } - - _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), - } -} - -/// Invokes the `np_max`, `np_min`, `np_argmax`, `np_argmin` functions -/// * `fn_name`: Can be one of `"np_argmin"`, `"np_argmax"`, `"np_max"`, `"np_min"` -pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - a: (Type, BasicValueEnum<'ctx>), - fn_name: &str, -) -> Result, String> { - debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); - - let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (a_ty, a) = a; - Ok(match a { - BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(a_ty, *ty))); - - match fn_name { - "np_argmin" | "np_argmax" => llvm_int64.const_zero().into(), - "np_max" | "np_min" => a, - _ => unreachable!(), - } - } - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - n_sz_eqz, - "0:ValueError", - format!("zero-size array to reduction operation {fn_name}").as_str(), - [None, None, None], - ctx.current_loc, - ); - } - - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; - - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); - - let result = match fn_name { - "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } - "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } - _ => unreachable!(), - }; - - let updated_idx = match (accumulator, result) { - (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx - .builder - .build_select( - ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, - "", - ) - .unwrap(), - (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx - .builder - .build_select( - ctx.builder - .build_float_compare(FloatPredicate::ONE, m, n, "") - .unwrap(), - idx.into(), - cur_idx, - "", - ) - .unwrap(), - _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), - }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - Ok(()) - }, - llvm_int64.const_int(1, false), - )?; - - match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), - _ => unreachable!(), - } - } - - _ => unsupported_type(ctx, fn_name, &[a_ty]), - }) -} - -/// Invokes the `np_maximum` builtin function. -pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_maximum"; - - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; - - Ok(match (x1, x2) { - (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); - - call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } - - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); - - call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } - - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) -} - /// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar. /// /// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument. @@ -1066,318 +82,6 @@ where Ok(result) } -pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "abs"; - helper_call_numpy_unary_elementwise( - generator, - ctx, - n, - FN_NAME, - &|_ctx, elem_ty| elem_ty, - &|_generator, ctx, val_ty, val| match val { - BasicValueEnum::IntValue(n) => Some({ - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(val_ty, *ty))); - - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(val_ty, *ty)) - { - llvm_intrinsics::call_int_abs( - ctx, - n, - ctx.ctx.bool_type().const_zero(), - Some(FN_NAME), - ) - .into() - } else { - n.into() - } - }), - - BasicValueEnum::FloatValue(n) => Some({ - debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() - }), - - _ => None, - }, - ) -} - -/// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`]. -/// -/// Arguments: -/// * `$name:ident`: The identifier of the rust function to be generated. -/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`] -/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`]. -/// But there is no need to make it a reference. -/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`]. -/// But there is no need to make it a reference. -macro_rules! create_helper_call_numpy_unary_elementwise { - ($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => { - #[allow(clippy::redundant_closure_call)] - pub fn $name<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - arg: (Type, BasicValueEnum<'ctx>), - ) -> Result, String> { - helper_call_numpy_unary_elementwise( - generator, - ctx, - arg, - $fn_name, - &$get_ret_elem_type, - &$on_scalar, - ) - } - }; -} - -/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise. -/// -/// Arguments: -/// * `$name:ident`: The identifier of the rust function to be generated. -/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]. -/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns -/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`. -/// -/// ```ignore -/// // Type of `$on_scalar:expr` -/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>( -/// generator: &mut G, -/// ctx: &mut CodeGenContext<'ctx, '_>, -/// arg: FloatValue<'ctx> -/// ) -> IntValue<'ctx> // of LLVM type `i1` -/// ``` -macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool { - ($name:ident, $fn_name:literal, $on_scalar:expr) => { - create_helper_call_numpy_unary_elementwise!( - $name, - $fn_name, - |ctx, _| ctx.primitives.bool, - |generator, ctx, n_ty, val| { - match val { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let ret = $on_scalar(generator, ctx, n); - Some(generator.bool_to_i8(ctx, ret).into()) - } - _ => None, - } - } - ); - }; -} - -/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise. -/// -/// Arguments: -/// * `$name:ident`: The identifier of the rust function to be generated. -/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]. -/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results. -/// -/// ```ignore -/// // Type of `$on_scalar:expr` -/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>( -/// generator: &mut G, -/// ctx: &mut CodeGenContext<'ctx, '_>, -/// arg: FloatValue<'ctx> -/// ) -> FloatValue<'ctx> -/// ``` -macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float { - ($name:ident, $fn_name:literal, $elem_call:expr) => { - create_helper_call_numpy_unary_elementwise!( - $name, - $fn_name, - |ctx, _| ctx.primitives.float, - |_generator, ctx, val_ty, val| { - match val { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); - - Some($elem_call(ctx, n, Option::<&str>::None).into()) - } - _ => None, - } - } - ); - }; -} - -create_helper_call_numpy_unary_elementwise_float_to_bool!( - call_numpy_isnan, - "np_isnan", - irrt::call_isnan -); -create_helper_call_numpy_unary_elementwise_float_to_bool!( - call_numpy_isinf, - "np_isinf", - irrt::call_isinf -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_sin, - "np_sin", - llvm_intrinsics::call_float_sin -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_cos, - "np_cos", - llvm_intrinsics::call_float_cos -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_tan, - "np_tan", - extern_fns::call_tan -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arcsin, - "np_arcsin", - extern_fns::call_asin -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arccos, - "np_arccos", - extern_fns::call_acos -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arctan, - "np_arctan", - extern_fns::call_atan -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_sinh, - "np_sinh", - extern_fns::call_sinh -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_cosh, - "np_cosh", - extern_fns::call_cosh -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_tanh, - "np_tanh", - extern_fns::call_tanh -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arcsinh, - "np_arcsinh", - extern_fns::call_asinh -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arccosh, - "np_arccosh", - extern_fns::call_acosh -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_arctanh, - "np_arctanh", - extern_fns::call_atanh -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_exp, - "np_exp", - llvm_intrinsics::call_float_exp -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_exp2, - "np_exp2", - llvm_intrinsics::call_float_exp2 -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_expm1, - "np_expm1", - extern_fns::call_expm1 -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_log, - "np_log", - llvm_intrinsics::call_float_log -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_log2, - "np_log2", - llvm_intrinsics::call_float_log2 -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_log10, - "np_log10", - llvm_intrinsics::call_float_log10 -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_sqrt, - "np_sqrt", - llvm_intrinsics::call_float_sqrt -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_cbrt, - "np_cbrt", - extern_fns::call_cbrt -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_fabs, - "np_fabs", - llvm_intrinsics::call_float_fabs -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_numpy_rint, - "np_rint", - llvm_intrinsics::call_float_rint -); - -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_erf, - "sp_spec_erf", - extern_fns::call_erf -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_erfc, - "sp_spec_erfc", - extern_fns::call_erfc -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_gamma, - "sp_spec_gamma", - |ctx, val, _| irrt::call_gamma(ctx, val) -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_gammaln, - "sp_spec_gammaln", - |ctx, val, _| irrt::call_gammaln(ctx, val) -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_j0, - "sp_spec_j0", - |ctx, val, _| irrt::call_j0(ctx, val) -); -create_helper_call_numpy_unary_elementwise_float_to_float!( - call_scipy_special_j1, - "sp_spec_j1", - extern_fns::call_j1 -); - /// Invokes the `np_arctan2` builtin function. pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/numpy_new.rs b/nac3core/src/codegen/numpy_new.rs index 48917551..1820b61e 100644 --- a/nac3core/src/codegen/numpy_new.rs +++ b/nac3core/src/codegen/numpy_new.rs @@ -15,11 +15,11 @@ use crate::{ }, }, symbol_resolver::ValueEnum, - toplevel::{numpy::unpack_ndarray_var_tys, DefinitionId}, - typecheck::{ - numpy::extract_ndims, - typedef::{FunSignature, Type}, + toplevel::{ + numpy::{extract_ndims, unpack_ndarray_var_tys}, + DefinitionId, }, + typecheck::typedef::{FunSignature, Type}, }; use super::{ diff --git a/nac3core/src/codegen/structure/ndarray/functions.rs b/nac3core/src/codegen/structure/ndarray/functions.rs index b594d1d5..1bd812e6 100644 --- a/nac3core/src/codegen/structure/ndarray/functions.rs +++ b/nac3core/src/codegen/structure/ndarray/functions.rs @@ -68,7 +68,7 @@ fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, scalar: ScalarObject<'ctx>, - target_int_dtype: Type, + ret_int_dtype: Type, handle_float: HandleFloatFn, ) -> ScalarObject<'ctx> where @@ -76,7 +76,7 @@ where HandleFloatFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>, { - let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { // Special handling for floats @@ -85,20 +85,44 @@ where } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { let n = scalar.value.into_int_value(); - if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() { - ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap() + if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() { + ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap() } else { - ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap() + ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap() } } else { unsupported_type(ctx, [scalar.dtype]); }; - assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check - ScalarObject { value: result.into(), dtype: target_int_dtype } + assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check + ScalarObject { value: result.into(), dtype: ret_int_dtype } } impl<'ctx> ScalarObject<'ctx> { + /// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value. + /// + /// Panic if the type is wrong. + pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> { + if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + self.value.into_float_value() // self.value must be a FloatValue + } else { + panic!("not a float type") + } + } + + /// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value. + /// + /// Panic if the type is wrong. + pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) { + let value = self.value.into_int_value(); + debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check + value + } else { + panic!("not a float type") + } + } + /// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed. /// Panic otherwise. pub fn compare( @@ -238,10 +262,7 @@ impl<'ctx> ScalarObject<'ctx> { /// Invoke NAC3's builtin `bool()`. #[must_use] - pub fn cast_to_bool( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Self { + pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { // TODO: Why is the original code being so lax about i1 and i8 for the returned int type? let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { self.value.into_int_value() @@ -262,24 +283,47 @@ impl<'ctx> ScalarObject<'ctx> { ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() } } + /// Invoke NAC3's builtin `float()`. + #[must_use] + pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { + let llvm_f64 = ctx.ctx.f64_type(); + + let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + self.value.into_float_value() + } else if ctx + .unifier + .unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat()) + { + let n = self.value.into_int_value(); + ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() + } else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) { + let n = self.value.into_int_value(); + ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap() + } else { + unsupported_type(ctx, [self.dtype]); + }; + + ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } + } + /// Invoke NAC3's builtin `round()`. #[must_use] pub fn round( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - target_int_dtype: Type, + ret_int_dtype: Type, ) -> Self { - let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let n = self.value.into_float_value(); let n = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap() + ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap() } else { - unsupported_type(ctx, [self.dtype, target_int_dtype]) + unsupported_type(ctx, [self.dtype, ret_int_dtype]) }; - ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() } + ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() } } /// Invoke NAC3's builtin `np_round()`. @@ -287,7 +331,7 @@ impl<'ctx> ScalarObject<'ctx> { /// NOTE: `np.round()` has different behaviors than `round()` in terms of their result /// on "tie" cases and return type. #[must_use] - pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { + pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let n = self.value.into_float_value(); llvm_intrinsics::call_float_rint(ctx, n, None) @@ -298,7 +342,7 @@ impl<'ctx> ScalarObject<'ctx> { } /// Invoke NAC3's builtin `min()` or `max()`. - fn min_or_max_helper( + pub fn min_or_max( ctx: &mut CodeGenContext<'ctx, '_>, kind: MinOrMax, a: Self, @@ -335,15 +379,19 @@ impl<'ctx> ScalarObject<'ctx> { } /// Invoke NAC3's builtin `floor()` or `ceil()`. + /// + /// * `ret_int_dtype` - The type of int to return. + /// + /// Takes in a float/int and returns an int of type `ret_int_dtype` #[must_use] pub fn floor_or_ceil( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil, - target_int_dtype: Type, + ret_int_dtype: Type, ) -> Self { - let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type(); + let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let function = match kind { @@ -352,8 +400,45 @@ impl<'ctx> ScalarObject<'ctx> { }; let n = self.value.into_float_value(); let n = function(ctx, n, None); - let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap(); - ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() } + + let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap(); + ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() } + } else { + unsupported_type(ctx, [self.dtype]) + } + } + + /// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`. + /// + /// Takes in a float/int and returns a float64 result. + #[must_use] + pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self { + if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let function = match kind { + FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, + FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, + }; + let n = self.value.into_float_value(); + let n = function(ctx, n, None); + ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() } + } else { + unsupported_type(ctx, [self.dtype]) + } + } + + /// Invoke NAC3's builtin `abs()`. + pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { + if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { + let n = self.value.into_float_value(); + let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); + ScalarObject { value: n.into(), dtype: ctx.primitives.float } + } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { + let n = self.value.into_int_value(); + + let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false + let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); + + ScalarObject { value: n.into(), dtype: self.dtype } } else { unsupported_type(ctx, [self.dtype]) } @@ -361,12 +446,12 @@ impl<'ctx> ScalarObject<'ctx> { } impl<'ctx> NDArrayObject<'ctx> { - /// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`. + /// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`. /// /// Generate LLVM IR to find the extremum and index of the **first** extremum value. /// /// Care has also been taken to make the error messages match that of NumPy. - fn min_or_max_helper( + fn min_max_argmin_argmax_helper( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, @@ -410,7 +495,7 @@ impl<'ctx> NDArrayObject<'ctx> { let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; - let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar); + let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); // Check if new_extremum is more extreme than old_extremum. let update_index = ScalarObject::compare( @@ -455,7 +540,7 @@ impl<'ctx> NDArrayObject<'ctx> { MinOrMax::Max => "maximum", } ); - self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0 + self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).0 } /// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`. @@ -472,6 +557,6 @@ impl<'ctx> NDArrayObject<'ctx> { MinOrMax::Max => "argmax", } ); - self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1 + self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).1 } } diff --git a/nac3core/src/codegen/structure/ndarray/mapping.rs b/nac3core/src/codegen/structure/ndarray/mapping.rs index fddd953a..e30806ba 100644 --- a/nac3core/src/codegen/structure/ndarray/mapping.rs +++ b/nac3core/src/codegen/structure/ndarray/mapping.rs @@ -15,12 +15,12 @@ use super::scalar::ScalarOrNDArray; impl<'ctx> NDArrayObject<'ctx> { /// TODO: Document me. Has complex behavior. + /// and explain why `ret_dtype` has to be specified beforehand. pub fn broadcasting_starmap<'a, G, MappingFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, ndarrays: &[Self], ret_dtype: Type, - name: &str, mapping: MappingFn, ) -> Result where @@ -30,7 +30,7 @@ impl<'ctx> NDArrayObject<'ctx> { &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, &[ScalarObject<'ctx>], - ) -> Result, String>, + ) -> Result, String>, { let sizet_model = IntModel(SizeT); @@ -43,7 +43,7 @@ impl<'ctx> NDArrayObject<'ctx> { ctx, ret_dtype, broadcast_result.ndims, - name, + "mapped_ndarray", ); mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); mapped_ndarray.create_data(generator, ctx); @@ -59,7 +59,7 @@ impl<'ctx> NDArrayObject<'ctx> { let ret = mapping(generator, ctx, i, &elements)?; let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret"); - ctx.builder.build_store(pret, ret.value).unwrap(); + ctx.builder.build_store(pret, ret).unwrap(); Ok(()) })?; @@ -71,7 +71,6 @@ impl<'ctx> NDArrayObject<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, ret_dtype: Type, - name: &str, mapping: Mapping, ) -> Result where @@ -88,23 +87,19 @@ impl<'ctx> NDArrayObject<'ctx> { ctx, &[*self], ret_dtype, - name, - |generator, ctx, i, scalars| { - let value = mapping(generator, ctx, i, scalars[0])?; - Ok(ScalarObject { dtype: ret_dtype, value }) - }, + |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), ) } } impl<'ctx> ScalarOrNDArray<'ctx> { /// TODO: Document me. Has complex behavior. + /// and explain why `ret_dtype` has to be specified beforehand. pub fn broadcasting_starmap<'a, G, MappingFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, inputs: &[Self], ret_dtype: Type, - name: &str, mapping: MappingFn, ) -> Result where @@ -114,7 +109,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, &[ScalarObject<'ctx>], - ) -> Result, String>, + ) -> Result, String>, { let sizet_model = IntModel(SizeT); @@ -124,15 +119,40 @@ impl<'ctx> ScalarOrNDArray<'ctx> { if let Some(scalars) = all_scalars { let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index - let scalar = mapping(generator, ctx, i, &scalars)?; + let scalar = + ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype }; Ok(ScalarOrNDArray::Scalar(scalar)) } else { // Promote all input to ndarrays and map through them. let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec(); - let ndarray = NDArrayObject::broadcasting_starmap( - generator, ctx, &inputs, ret_dtype, name, mapping, - )?; + let ndarray = + NDArrayObject::broadcasting_starmap(generator, ctx, &inputs, ret_dtype, mapping)?; Ok(ScalarOrNDArray::NDArray(ndarray)) } } + + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: Type, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Int<'ctx, SizeT>, + ScalarObject<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), + ) + } } diff --git a/nac3core/src/codegen/structure/ndarray/mod.rs b/nac3core/src/codegen/structure/ndarray/mod.rs index eaf3de7d..2a0979ea 100644 --- a/nac3core/src/codegen/structure/ndarray/mod.rs +++ b/nac3core/src/codegen/structure/ndarray/mod.rs @@ -229,6 +229,26 @@ impl<'ctx> NDArrayObject<'ctx> { Self::alloca_uninitialized(generator, ctx, dtype, ndims, name) } + /// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents + /// over. + /// + /// The new ndarray will own its data and will be C-contiguous. + pub fn make_clone( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: &str, + ) -> Self { + let clone = + NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name); + + let shape = self.value.gep(ctx, |f| f.shape).load(generator, ctx, "shape"); + clone.copy_shape_from_array(generator, ctx, shape); + clone.create_data(generator, ctx); + clone.copy_data_from(generator, ctx, *self); + clone + } + /// Get this ndarray's `ndims` as an LLVM constant. pub fn get_ndims( &self, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index cebbebe9..aa4a7707 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -13,22 +13,28 @@ use strum::IntoEnumIterator; use crate::{ codegen::{ - builtin_fns, + builtin_fns::{self}, classes::{ProxyValue, RangeValue}, expr::destructure_range, - irrt::*, + extern_fns, + irrt::{self, *}, + llvm_intrinsics, model::Int32, numpy::*, numpy_new::{self, gen_ndarray_transpose}, stmt::exn_constructor, - structure::ndarray::NDArrayObject, + structure::ndarray::{ + functions::{FloorOrCeil, MinOrMax}, + scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray}, + NDArrayObject, + }, }, symbol_resolver::SymbolValue, - toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, - typecheck::{ - numpy::create_ndims, - typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, + toplevel::{ + helper::PrimDef, + numpy::{create_ndims, make_ndarray_ty}, }, + typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, }; use super::*; @@ -1053,16 +1059,34 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let func = match prim { - PrimDef::FunInt32 => builtin_fns::call_int32, - PrimDef::FunInt64 => builtin_fns::call_int64, - PrimDef::FunUInt32 => builtin_fns::call_uint32, - PrimDef::FunUInt64 => builtin_fns::call_uint64, - PrimDef::FunFloat => builtin_fns::call_float, - PrimDef::FunBool => builtin_fns::call_bool, + let ret_dtype = match prim { + PrimDef::FunInt32 => ctx.primitives.int32, + PrimDef::FunInt64 => ctx.primitives.int64, + PrimDef::FunUInt32 => ctx.primitives.uint32, + PrimDef::FunUInt64 => ctx.primitives.uint64, + PrimDef::FunFloat => ctx.primitives.float, + PrimDef::FunBool => ctx.primitives.bool, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (arg_ty, arg))?)) + + let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map( + generator, + ctx, + ret_dtype, + |generator, ctx, _i, scalar| { + let result = match prim { + PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx), + PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx), + PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx), + PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx), + PrimDef::FunFloat => scalar.cast_to_float(ctx), + PrimDef::FunBool => scalar.cast_to_bool(ctx), + _ => unreachable!(), + }; + Ok(result.value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }, )))), loc: None, @@ -1113,20 +1137,23 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); - Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) + let ret_int_dtype = size_variant.of_int(&ctx.primitives); + + let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map( + generator, + ctx, + ret_int_dtype, + |generator, ctx, _i, scalar| { + Ok(scalar.round(generator, ctx, ret_int_dtype).value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } /// Build the functions `ceil()` and `floor()` and their 64 bit variants. fn build_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef { - #[derive(Clone, Copy)] - enum Kind { - Floor, - Ceil, - } - debug_assert_prim_is_allowed( prim, &[PrimDef::FunFloor, PrimDef::FunFloor64, PrimDef::FunCeil, PrimDef::FunCeil64], @@ -1134,10 +1161,10 @@ impl<'a> BuiltinBuilder<'a> { let (size_variant, kind) = { match prim { - PrimDef::FunFloor => (SizeVariant::Bits32, Kind::Floor), - PrimDef::FunFloor64 => (SizeVariant::Bits64, Kind::Floor), - PrimDef::FunCeil => (SizeVariant::Bits32, Kind::Ceil), - PrimDef::FunCeil64 => (SizeVariant::Bits64, Kind::Ceil), + PrimDef::FunFloor => (SizeVariant::Bits32, FloorOrCeil::Floor), + PrimDef::FunFloor64 => (SizeVariant::Bits64, FloorOrCeil::Floor), + PrimDef::FunCeil => (SizeVariant::Bits32, FloorOrCeil::Ceil), + PrimDef::FunCeil64 => (SizeVariant::Bits64, FloorOrCeil::Ceil), _ => unreachable!(), } }; @@ -1177,12 +1204,15 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); - let func = match kind { - Kind::Ceil => builtin_fns::call_ceil, - Kind::Floor => builtin_fns::call_floor, - }; - Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) + let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map( + generator, + ctx, + int_sized, + |generator, ctx, _i, scalar| { + Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } @@ -1546,12 +1576,22 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let func = match prim { - PrimDef::FunNpCeil => builtin_fns::call_ceil, - PrimDef::FunNpFloor => builtin_fns::call_floor, + let kind = match prim { + PrimDef::FunNpFloor => FloorOrCeil::Floor, + PrimDef::FunNpCeil => FloorOrCeil::Ceil, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) + + let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map( + generator, + ctx, + ctx.primitives.float, + move |_generator, ctx, _i, scalar| { + let result = scalar.np_floor_or_ceil(ctx, kind); + Ok(result.value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } @@ -1569,7 +1609,17 @@ impl<'a> BuiltinBuilder<'a> { Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) + + let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map( + generator, + ctx, + ctx.primitives.float, + |_generator, ctx, _i, scalar| { + let result = scalar.np_round(ctx); + Ok(result.value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } @@ -1678,16 +1728,21 @@ impl<'a> BuiltinBuilder<'a> { codegen_callback: Some(Arc::new(GenCall::new(Box::new( move |ctx, _, fun, args, generator| { let m_ty = fun.0.args[0].ty; - let n_ty = fun.0.args[1].ty; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; + + let n_ty = fun.0.args[1].ty; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - let func = match prim { - PrimDef::FunMin => builtin_fns::call_min, - PrimDef::FunMax => builtin_fns::call_max, + let kind = match prim { + PrimDef::FunMin => MinOrMax::Min, + PrimDef::FunMax => MinOrMax::Max, _ => unreachable!(), }; - Ok(Some(func(ctx, (m_ty, m_val), (n_ty, n_val)))) + + let m = ScalarObject { dtype: m_ty, value: m_val }; + let n = ScalarObject { dtype: n_ty, value: n_val }; + let result = ScalarObject::min_or_max(ctx, kind, m, n); + Ok(Some(result.value)) }, )))), loc: None, @@ -1729,7 +1784,25 @@ impl<'a> BuiltinBuilder<'a> { let a_ty = fun.0.args[0].ty; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; - Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?)) + let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx); + let result = match prim { + PrimDef::FunNpArgmin => a + .argmin_or_argmax(generator, ctx, MinOrMax::Min) + .value + .as_basic_value_enum(), + PrimDef::FunNpArgmax => a + .argmin_or_argmax(generator, ctx, MinOrMax::Max) + .value + .as_basic_value_enum(), + PrimDef::FunNpMin => { + a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum() + } + PrimDef::FunNpMax => { + a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum() + } + _ => unreachable!(), + }; + Ok(Some(result)) }), ) } @@ -1764,13 +1837,32 @@ impl<'a> BuiltinBuilder<'a> { let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - let func = match prim { - PrimDef::FunNpMinimum => builtin_fns::call_numpy_minimum, - PrimDef::FunNpMaximum => builtin_fns::call_numpy_maximum, + let kind = match prim { + PrimDef::FunNpMinimum => MinOrMax::Min, + PrimDef::FunNpMaximum => MinOrMax::Max, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty); + let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty); + + // NOTE: x1.dtype() and x2.dtype() should be the same + let common_ty = x1.dtype(); + + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + common_ty, + |_generator, ctx, _i, scalars| { + let x1 = scalars[0]; + let x2 = scalars[1]; + + let result = ScalarObject::min_or_max(ctx, kind, x1, x2); + Ok(result.value) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }, )))), loc: None, @@ -1781,6 +1873,7 @@ impl<'a> BuiltinBuilder<'a> { fn build_abs_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunAbs; + let num_ty = self.num_ty; // To move into codegen_callback TopLevelDef::Function { name: prim.name().into(), simple_name: prim.simple_name().into(), @@ -1798,11 +1891,17 @@ impl<'a> BuiltinBuilder<'a> { instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { let n_ty = fun.0.args[0].ty; let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - Ok(Some(builtin_fns::call_abs(generator, ctx, (n_ty, n_val))?)) + let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map( + generator, + ctx, + num_ty.ty, + |_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value), + )?; + Ok(Some(result.to_basic_value_enum())) }, )))), loc: None, @@ -1825,13 +1924,23 @@ impl<'a> BuiltinBuilder<'a> { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - let func = match prim { - PrimDef::FunNpIsInf => builtin_fns::call_numpy_isinf, - PrimDef::FunNpIsNan => builtin_fns::call_numpy_isnan, + let function = match prim { + PrimDef::FunNpIsInf => irrt::call_isnan, + PrimDef::FunNpIsNan => irrt::call_isinf, _ => unreachable!(), }; - Ok(Some(func(generator, ctx, (x_ty, x_val))?)) + let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map( + generator, + ctx, + ctx.primitives.bool, + |generator, ctx, _i, scalar| { + let n = scalar.into_float64(ctx); + let n = function(generator, ctx, n); + Ok(n.as_basic_value_enum()) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } @@ -1889,49 +1998,58 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let func = match prim { - PrimDef::FunNpSin => builtin_fns::call_numpy_sin, - PrimDef::FunNpCos => builtin_fns::call_numpy_cos, - PrimDef::FunNpTan => builtin_fns::call_numpy_tan, + let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map( + generator, + ctx, + ctx.primitives.float, + |_generator, ctx, _i, scalar| { + let n = scalar.into_float64(ctx); + let n = match prim { + PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None), + PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None), + PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None), - PrimDef::FunNpArcsin => builtin_fns::call_numpy_arcsin, - PrimDef::FunNpArccos => builtin_fns::call_numpy_arccos, - PrimDef::FunNpArctan => builtin_fns::call_numpy_arctan, + PrimDef::FunNpArcsin => extern_fns::call_asin(ctx, n, None), + PrimDef::FunNpArccos => extern_fns::call_acos(ctx, n, None), + PrimDef::FunNpArctan => extern_fns::call_atan(ctx, n, None), - PrimDef::FunNpSinh => builtin_fns::call_numpy_sinh, - PrimDef::FunNpCosh => builtin_fns::call_numpy_cosh, - PrimDef::FunNpTanh => builtin_fns::call_numpy_tanh, + PrimDef::FunNpSinh => extern_fns::call_sinh(ctx, n, None), + PrimDef::FunNpCosh => extern_fns::call_cosh(ctx, n, None), + PrimDef::FunNpTanh => extern_fns::call_tanh(ctx, n, None), - PrimDef::FunNpArcsinh => builtin_fns::call_numpy_arcsinh, - PrimDef::FunNpArccosh => builtin_fns::call_numpy_arccosh, - PrimDef::FunNpArctanh => builtin_fns::call_numpy_arctanh, + PrimDef::FunNpArcsinh => extern_fns::call_asinh(ctx, n, None), + PrimDef::FunNpArccosh => extern_fns::call_acosh(ctx, n, None), + PrimDef::FunNpArctanh => extern_fns::call_atanh(ctx, n, None), - PrimDef::FunNpExp => builtin_fns::call_numpy_exp, - PrimDef::FunNpExp2 => builtin_fns::call_numpy_exp2, - PrimDef::FunNpExpm1 => builtin_fns::call_numpy_expm1, + PrimDef::FunNpExp => llvm_intrinsics::call_float_exp(ctx, n, None), + PrimDef::FunNpExp2 => llvm_intrinsics::call_float_exp2(ctx, n, None), + PrimDef::FunNpExpm1 => extern_fns::call_expm1(ctx, n, None), - PrimDef::FunNpLog => builtin_fns::call_numpy_log, - PrimDef::FunNpLog2 => builtin_fns::call_numpy_log2, - PrimDef::FunNpLog10 => builtin_fns::call_numpy_log10, + PrimDef::FunNpLog => llvm_intrinsics::call_float_log(ctx, n, None), + PrimDef::FunNpLog2 => llvm_intrinsics::call_float_log2(ctx, n, None), + PrimDef::FunNpLog10 => llvm_intrinsics::call_float_log10(ctx, n, None), - PrimDef::FunNpSqrt => builtin_fns::call_numpy_sqrt, - PrimDef::FunNpCbrt => builtin_fns::call_numpy_cbrt, + PrimDef::FunNpSqrt => llvm_intrinsics::call_float_sqrt(ctx, n, None), + PrimDef::FunNpCbrt => extern_fns::call_cbrt(ctx, n, None), - PrimDef::FunNpFabs => builtin_fns::call_numpy_fabs, - PrimDef::FunNpRint => builtin_fns::call_numpy_rint, + PrimDef::FunNpFabs => llvm_intrinsics::call_float_fabs(ctx, n, None), + PrimDef::FunNpRint => llvm_intrinsics::call_float_rint(ctx, n, None), - PrimDef::FunSpSpecErf => builtin_fns::call_scipy_special_erf, - PrimDef::FunSpSpecErfc => builtin_fns::call_scipy_special_erfc, + PrimDef::FunSpSpecErf => extern_fns::call_erf(ctx, n, None), + PrimDef::FunSpSpecErfc => extern_fns::call_erfc(ctx, n, None), - PrimDef::FunSpSpecGamma => builtin_fns::call_scipy_special_gamma, - PrimDef::FunSpSpecGammaln => builtin_fns::call_scipy_special_gammaln, + PrimDef::FunSpSpecGamma => irrt::call_gamma(ctx, n), + PrimDef::FunSpSpecGammaln => irrt::call_gammaln(ctx, n), - PrimDef::FunSpSpecJ0 => builtin_fns::call_scipy_special_j0, - PrimDef::FunSpSpecJ1 => builtin_fns::call_scipy_special_j1, + PrimDef::FunSpSpecJ0 => irrt::call_j0(ctx, n), + PrimDef::FunSpSpecJ1 => extern_fns::call_j1(ctx, n, None), - _ => unreachable!(), - }; - Ok(Some(func(generator, ctx, (arg_ty, arg_val))?)) + _ => unreachable!(), + }; + Ok(n.as_basic_value_enum()) + }, + )?; + Ok(Some(result.to_basic_value_enum())) }), ) } @@ -1953,20 +2071,20 @@ impl<'a> BuiltinBuilder<'a> { let PrimitiveStore { float, int32, .. } = *self.primitives; - // The argument types of the two input arguments are controlled here. - let (x1_ty, x2_ty) = match prim { + // The argument types of the two input arguments + the return type + let (x1_dtype, x2_dtype, ret_dtype) = match prim { PrimDef::FunNpArctan2 | PrimDef::FunNpCopysign | PrimDef::FunNpFmax | PrimDef::FunNpFmin | PrimDef::FunNpHypot - | PrimDef::FunNpNextAfter => (float, float), - PrimDef::FunNpLdExp => (float, int32), + | PrimDef::FunNpNextAfter => (float, float, float), + PrimDef::FunNpLdExp => (float, int32, float), _ => unreachable!(), }; - let x1_ty = self.new_type_or_ndarray_ty(x1_ty); - let x2_ty = self.new_type_or_ndarray_ty(x2_ty); + let x1_ty = self.new_type_or_ndarray_ty(x1_dtype); + let x2_ty = self.new_type_or_ndarray_ty(x2_dtype); let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")]; let ret_ty = self.unifier.get_fresh_var(None, None); @@ -1990,21 +2108,46 @@ impl<'a> BuiltinBuilder<'a> { move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - let func = match prim { - PrimDef::FunNpArctan2 => builtin_fns::call_numpy_arctan2, - PrimDef::FunNpCopysign => builtin_fns::call_numpy_copysign, - PrimDef::FunNpFmax => builtin_fns::call_numpy_fmax, - PrimDef::FunNpFmin => builtin_fns::call_numpy_fmin, - PrimDef::FunNpLdExp => builtin_fns::call_numpy_ldexp, - PrimDef::FunNpHypot => builtin_fns::call_numpy_hypot, - PrimDef::FunNpNextAfter => builtin_fns::call_numpy_nextafter, - _ => unreachable!(), - }; + let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty); + let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty); - Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ret_dtype, + |_generator, ctx, _i, scalars| { + let x1 = scalars[0]; + let x2 = scalars[1]; + + let result = match prim { + PrimDef::FunNpArctan2 + | PrimDef::FunNpCopysign + | PrimDef::FunNpFmax + | PrimDef::FunNpFmin + | PrimDef::FunNpHypot + | PrimDef::FunNpNextAfter => { + let x1 = x1.into_float64(ctx); + let x2 = x2.into_float64(ctx); + extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum() + } + PrimDef::FunNpLdExp => { + let x1 = x1.into_float64(ctx); + let x2 = x2.into_int32(ctx); + extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum() + } + _ => unreachable!(), + }; + + Ok(result) + }, + )?; + + Ok(Some(result.to_basic_value_enum())) }, )))), loc: None, diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 63f6173d..015b4eac 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use crate::{ + symbol_resolver::SymbolValue, toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, @@ -83,3 +86,33 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() } + +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +} + +/// Return the ndims after broadcasting ndarrays of different ndims. +/// +/// Panics if the input list is empty. +pub fn get_broadcast_all_ndims(ndims: I) -> u64 +where + I: IntoIterator, +{ + ndims.into_iter().max().unwrap() +} diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index ddd06fea..4cac1bf0 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -1,6 +1,5 @@ mod function_check; pub mod magic_methods; -pub mod numpy; pub mod type_error; pub mod type_inferencer; pub mod typedef; diff --git a/nac3core/src/typecheck/numpy.rs b/nac3core/src/typecheck/numpy.rs deleted file mode 100644 index 3a5a9654..00000000 --- a/nac3core/src/typecheck/numpy.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::{symbol_resolver::SymbolValue, typecheck::typedef::TypeEnum}; - -use super::typedef::{Type, Unifier}; - -/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. -/// The `ndims` must only contain 1 value. -#[must_use] -pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { - let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); - let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { - panic!("ndims_ty should be a TLiteral"); - }; - - assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); - - let ndims = values[0].clone(); - u64::try_from(ndims).unwrap() -} - -/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. -pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { - unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) -} - -/// Return the ndims after broadcasting ndarrays of different ndims. -/// -/// Panics if the input list is empty. -pub fn get_broadcast_all_ndims(ndims: I) -> u64 -where - I: IntoIterator, -{ - ndims.into_iter().max().unwrap() -} diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index eba87e10..54e29ee4 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -11,12 +11,11 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; -use crate::typecheck::numpy::extract_ndims; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, TopLevelDef, }, typecheck::typedef::Mapping,