From 30c6cffbad77886178d39b630b905409fbb18f8c Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 25 Apr 2024 15:47:16 +0800 Subject: [PATCH] core/builtins: Refactored numpy builtins to accept scalar and ndarrays --- nac3core/src/codegen/builtin_fns.rs | 2240 +++++++++++++---- nac3core/src/codegen/numpy.rs | 2 - nac3core/src/toplevel/builtins.rs | 1264 ++++++---- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 284 ++- nac3standalone/demo/interpret_demo.py | 45 +- nac3standalone/demo/src/ndarray.py | 687 +++++ 11 files changed, 3521 insertions(+), 1015 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 49ecdda05..4ff40c011 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,9 +1,13 @@ use inkwell::{FloatPredicate, IntPredicate}; -use inkwell::types::{BasicTypeEnum, IntType}; -use inkwell::values::{BasicValueEnum, FloatValue, IntValue}; +use inkwell::types::BasicTypeEnum; +use inkwell::values::BasicValueEnum; use itertools::Itertools; -use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics}; +use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy}; +use crate::codegen::classes::NDArrayValue; +use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; +use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. @@ -21,69 +25,93 @@ fn unsupported_type( } /// Invokes the `int32` builtin function. -pub fn call_int32<'ctx>( +pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); ctx.builder - .build_int_z_extend(n.into_int_value(), llvm_i32, "zext") + .build_int_z_extend(n, llvm_i32, "zext") + .map(Into::into) .unwrap() } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 32 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { debug_assert!([ ctx.primitives.int32, ctx.primitives.uint32, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - n.into_int_value() + n.into() } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { debug_assert!([ ctx.primitives.int64, ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); ctx.builder - .build_int_truncate(n.into_int_value(), llvm_i32, "trunc") + .build_int_truncate(n, llvm_i32, "trunc") + .map(Into::into) .unwrap() } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); let to_int64 = ctx.builder - .build_float_to_signed_int(n.into_float_value(), ctx.ctx.i64_type(), "") + .build_float_to_signed_int(n, ctx.ctx.i64_type(), "") .unwrap(); ctx.builder .build_int_truncate(to_int64, llvm_i32, "conv") + .map(Into::into) .unwrap() } + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.int32, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_int32(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, "int32", &[n_ty]) - } + }) } /// Invokes the `int64` builtin function. -pub fn call_int64<'ctx>( +pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -92,115 +120,156 @@ pub fn call_int64<'ctx>( if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { ctx.builder - .build_int_s_extend(n.into_int_value(), llvm_i64, "sext") + .build_int_s_extend(n, llvm_i64, "sext") + .map(Into::into) .unwrap() } else { ctx.builder - .build_int_z_extend(n.into_int_value(), llvm_i64, "zext") + .build_int_z_extend(n, llvm_i64, "zext") + .map(Into::into) .unwrap() } } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { debug_assert!([ ctx.primitives.int64, ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - n.into_int_value() + n.into() } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); ctx.builder - .build_float_to_signed_int(n.into_float_value(), ctx.ctx.i64_type(), "fptosi") + .build_float_to_signed_int(n, ctx.ctx.i64_type(), "fptosi") + .map(Into::into) .unwrap() } + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.int64, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_int64(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, "int64", &[n_ty]) - } + }) } /// Invokes the `uint32` builtin function. -pub fn call_uint32<'ctx>( +pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); ctx.builder - .build_int_z_extend(n.into_int_value(), llvm_i32, "zext") + .build_int_z_extend(n, llvm_i32, "zext") + .map(Into::into) .unwrap() } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 32 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { debug_assert!([ ctx.primitives.int32, ctx.primitives.uint32, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - n.into_int_value() + n.into() } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { debug_assert!( ctx.unifier.unioned(n_ty, ctx.primitives.int64) || ctx.unifier.unioned(n_ty, ctx.primitives.uint64) ); ctx.builder - .build_int_truncate(n.into_int_value(), llvm_i32, "trunc") + .build_int_truncate(n, llvm_i32, "trunc") + .map(Into::into) .unwrap() } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = n.into_float_value(); - let val_gez = ctx.builder - .build_float_compare(FloatPredicate::OGE, val, val.get_type().const_zero(), "") + let n_gez = ctx.builder + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); let to_int32 = ctx.builder - .build_float_to_signed_int(val, llvm_i32, "") + .build_float_to_signed_int(n, llvm_i32, "") .unwrap(); let to_uint64 = ctx.builder - .build_float_to_unsigned_int(val, ctx.ctx.i64_type(), "") + .build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "") .unwrap(); ctx.builder .build_select( - val_gez, + n_gez, ctx.builder.build_int_truncate(to_uint64, llvm_i32, "").unwrap(), to_int32, "conv", ) - .map(BasicValueEnum::into_int_value) .unwrap() } + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.uint32, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_uint32(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, "uint32", &[n_ty]) - } + }) } /// Invokes the `uint64` builtin function. -pub fn call_uint64<'ctx>( +pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -209,60 +278,79 @@ pub fn call_uint64<'ctx>( if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { ctx.builder - .build_int_s_extend(n.into_int_value(), llvm_i64, "sext") + .build_int_s_extend(n, llvm_i64, "sext") + .map(Into::into) .unwrap() } else { ctx.builder - .build_int_z_extend(n.into_int_value(), llvm_i64, "zext") + .build_int_z_extend(n, llvm_i64, "zext") + .map(Into::into) .unwrap() } } - BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => { + BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { debug_assert!([ ctx.primitives.int64, ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - n.into_int_value() + n.into() } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = n.into_float_value(); let val_gez = ctx.builder - .build_float_compare(FloatPredicate::OGE, val, val.get_type().const_zero(), "") + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); let to_int64 = ctx.builder - .build_float_to_signed_int(val, llvm_i64, "") + .build_float_to_signed_int(n, llvm_i64, "") .unwrap(); let to_uint64 = ctx.builder - .build_float_to_unsigned_int(val, llvm_i64, "") + .build_float_to_unsigned_int(n, llvm_i64, "") .unwrap(); ctx.builder .build_select(val_gez, to_uint64, to_int64, "conv") - .map(BasicValueEnum::into_int_value) .unwrap() } + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.uint64, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_uint64(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, "uint64", &[n_ty]) - } + }) } /// Invokes the `float` builtin function. -pub fn call_float<'ctx>( +pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> FloatValue<'ctx> { +) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32 | 64) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -277,76 +365,150 @@ pub fn call_float<'ctx>( ctx.primitives.int64, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { ctx.builder - .build_signed_int_to_float(n.into_int_value(), llvm_f64, "sitofp") + .build_signed_int_to_float(n, llvm_f64, "sitofp") + .map(Into::into) .unwrap() } else { ctx.builder - .build_unsigned_int_to_float(n.into_int_value(), llvm_f64, "uitofp") - .unwrap() + .build_unsigned_int_to_float(n, llvm_f64, "uitofp") + .map(Into::into) + .unwrap() } } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - n.into_float_value() + n.into() + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_float(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]) - } + }) } /// Invokes the `round` builtin function. -pub fn call_round<'ctx>( +pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: IntType<'ctx>, -) -> IntValue<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, +) -> Result, String> { const FN_NAME: &str = "round"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); - if !ctx.unifier.unioned(n_ty, ctx.primitives.float) { - unsupported_type(ctx, FN_NAME, &[n_ty]) - } + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty, FN_NAME) - .unwrap() + let val = llvm_intrinsics::call_float_round(ctx, n, None); + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_round(generator, ctx, (elem_ty, val), ret_elem_ty) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + }) } /// Invokes the `np_round` builtin function. -pub fn call_numpy_round<'ctx>( +pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + n: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_round"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - if !ctx.unifier.unioned(n_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_round", &[n_ty]) - } + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_roundeven(ctx, n, None) + llvm_intrinsics::call_float_roundeven(ctx, n, None).into() + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_numpy_round(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + }) } /// Invokes the `bool` builtin function. -pub fn call_bool<'ctx>( +pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> IntValue<'ctx> { +) -> Result, String> { const FN_NAME: &str = "bool"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => { + Ok(match n { + BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - n.into_int_value() + n.into() } - BasicTypeEnum::IntType(_) => { + BasicValueEnum::IntValue(n) => { debug_assert!([ ctx.primitives.int32, ctx.primitives.uint32, @@ -354,75 +516,146 @@ pub fn call_bool<'ctx>( ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - let val = n.into_int_value(); ctx.builder - .build_int_compare(IntPredicate::NE, val, val.get_type().const_zero(), FN_NAME) + .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = n.into_float_value(); ctx.builder - .build_float_compare(FloatPredicate::UNE, val, val.get_type().const_zero(), FN_NAME) + .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) + .map(Into::into) .unwrap() } + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + let elem = call_bool( + generator, + ctx, + (elem_ty, val), + )?; + + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `floor` builtin function. -pub fn call_floor<'ctx>( +pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: BasicTypeEnum<'ctx>, -) -> BasicValueEnum<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, +) -> Result, String> { const FN_NAME: &str = "floor"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - let val = llvm_intrinsics::call_float_floor(ctx, n, None); - match llvm_ret_ty { - _ if llvm_ret_ty == val.get_type().into() => val.into(), + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - BasicTypeEnum::IntType(_) => { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) - .map(Into::into) - .unwrap() + let val = llvm_intrinsics::call_float_floor(ctx, n, None); + if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_floor(generator, ctx, (elem_ty, val), ret_elem_ty) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `ceil` builtin function. -pub fn call_ceil<'ctx>( +pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: BasicTypeEnum<'ctx>, -) -> BasicValueEnum<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_elem_ty: Type, +) -> Result, String> { const FN_NAME: &str = "ceil"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - let val = llvm_intrinsics::call_float_ceil(ctx, n, None); - match llvm_ret_ty { - _ if llvm_ret_ty == val.get_type().into() => val.into(), + Ok(match n { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - BasicTypeEnum::IntType(_) => { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) - .map(Into::into) - .unwrap() + let val = llvm_intrinsics::call_float_ceil(ctx, n, None); + if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_floor(generator, ctx, (elem_ty, val), ret_elem_ty) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `min` builtin function. @@ -436,16 +669,14 @@ pub fn call_min<'ctx>( let (m_ty, m) = m; let (n_ty, n) = n; - if !ctx.unifier.unioned(m_ty, n_ty) { + let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { + m_ty + } else { unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) - } - debug_assert_eq!(m.get_type(), n.get_type()); + }; - let common_ty = m_ty; - let llvm_common_ty = m.get_type(); - - match llvm_common_ty { - BasicTypeEnum::IntType(_) => { + match (m, n) { + (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -454,8 +685,6 @@ pub fn call_min<'ctx>( ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty))); - let (m, n) = (m.into_int_value(), n.into_int_value()); - if [ ctx.primitives.int32, ctx.primitives.int64, @@ -466,11 +695,9 @@ pub fn call_min<'ctx>( } } - BasicTypeEnum::FloatType(_) => { + (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); - let (m, n) = (m.into_float_value(), n.into_float_value()); - llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() } @@ -489,16 +716,14 @@ pub fn call_max<'ctx>( let (m_ty, m) = m; let (n_ty, n) = n; - if !ctx.unifier.unioned(m_ty, n_ty) { + let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { + m_ty + } else { unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) - } - debug_assert_eq!(m.get_type(), n.get_type()); + }; - let common_ty = m_ty; - let llvm_common_ty = m.get_type(); - - match llvm_common_ty { - BasicTypeEnum::IntType(_) => { + match (m, n) { + (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -507,8 +732,6 @@ pub fn call_max<'ctx>( ctx.primitives.uint64, ].iter().any(|ty| ctx.unifier.unioned(common_ty, *ty))); - let (m, n) = (m.into_int_value(), n.into_int_value()); - if [ ctx.primitives.int32, ctx.primitives.int64, @@ -519,11 +742,9 @@ pub fn call_max<'ctx>( } } - BasicTypeEnum::FloatType(_) => { + (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); - let (m, n) = (m.into_float_value(), n.into_float_value()); - llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() } @@ -532,18 +753,20 @@ pub fn call_max<'ctx>( } /// Invokes the `abs` builtin function. -pub fn call_abs<'ctx>( +pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), -) -> BasicValueEnum<'ctx> { +) -> Result, String> { const FN_NAME: &str = "abs"; let llvm_i1 = ctx.ctx.bool_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - match n.get_type() { - BasicTypeEnum::IntType(_) => { + Ok(match n { + BasicValueEnum::IntValue(n) => { debug_assert!([ ctx.primitives.bool, ctx.primitives.int32, @@ -558,573 +781,1700 @@ pub fn call_abs<'ctx>( ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) { llvm_intrinsics::call_int_abs( ctx, - n.into_int_value(), + n, llvm_i1.const_zero(), Some(FN_NAME), ).into() } else { - n + n.into() } } - BasicTypeEnum::FloatType(_) => { + BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_fabs(ctx, n.into_float_value(), Some(FN_NAME)).into() + llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() + } + + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(n, llvm_usize, None), + |generator, ctx, val| { + call_abs(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `np_isnan` builtin function. pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> IntValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_isnan"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_isnan", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_isnan(generator, ctx, x) + irrt::call_isnan(generator, ctx, x).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + let val = call_numpy_isnan(generator, ctx, (elem_ty, val))?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_isinf` builtin function. pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> IntValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_isinf"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_isinf", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_isinf(generator, ctx, x) + irrt::call_isinf(generator, ctx, x).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + let val = call_numpy_isinf(generator, ctx, (elem_ty, val))?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_sin` builtin function. -pub fn call_numpy_sin<'ctx>( +pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_sin"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_sin", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_sin(ctx, x, None) + llvm_intrinsics::call_float_sin(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_sin(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_cos` builtin function. -pub fn call_numpy_cos<'ctx>( +pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_cos"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_cos", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_cos(ctx, x, None) + llvm_intrinsics::call_float_cos(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_cos(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_exp` builtin function. -pub fn call_numpy_exp<'ctx>( +pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_exp"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_exp", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_exp(ctx, x, None) + llvm_intrinsics::call_float_exp(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_exp(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_exp2` builtin function. -pub fn call_numpy_exp2<'ctx>( +pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_exp2"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_exp2", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_exp2(ctx, x, None) + llvm_intrinsics::call_float_exp2(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_exp2(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log` builtin function. -pub fn call_numpy_log<'ctx>( +pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log(ctx, x, None) + llvm_intrinsics::call_float_log(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_log(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log10` builtin function. -pub fn call_numpy_log10<'ctx>( +pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log10"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log10", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log10(ctx, x, None) + llvm_intrinsics::call_float_log10(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_log10(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_log2` builtin function. -pub fn call_numpy_log2<'ctx>( +pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_log2"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_log2", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_log2(ctx, x, None) + llvm_intrinsics::call_float_log2(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_log2(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) +} + +/// Invokes the `np_fabs` builtin function. +pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_fabs"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (x_ty, x) = x; + + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); + + llvm_intrinsics::call_float_fabs(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_fabs(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_sqrt` builtin function. -pub fn call_numpy_fabs<'ctx>( +pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_sqrt"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_fabs", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_fabs(ctx, x, None) -} + llvm_intrinsics::call_float_sqrt(ctx, x, None).into() + } -/// Invokes the `np_sqrt` builtin function. -pub fn call_numpy_sqrt<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let (x_ty, x) = x; + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_sqrt", &[x_ty]) - } + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_sqrt(generator, ctx, (elem_ty, val)) + }, + )?; - llvm_intrinsics::call_float_sqrt(ctx, x, None) + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_rint` builtin function. -pub fn call_numpy_rint<'ctx>( +pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_rint"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_rint", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_roundeven(ctx, x, None) + llvm_intrinsics::call_float_roundeven(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_rint(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_tan` builtin function. -pub fn call_numpy_tan<'ctx>( +pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_tan"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_tan", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_tan(ctx, x, None) + extern_fns::call_tan(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_tan(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_arcsin` builtin function. -pub fn call_numpy_arcsin<'ctx>( +pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arcsin"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_arcsin", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_asin(ctx, x, None) + extern_fns::call_asin(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arcsin(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_arccos` builtin function. -pub fn call_numpy_arccos<'ctx>( +pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arccos"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_arccos", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); + + extern_fns::call_acos(ctx, x, None).into() + } - extern_fns::call_acos(ctx, x, None) + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arccos(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_arctan` builtin function. -pub fn call_numpy_arctan<'ctx>( +pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arctan"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_arctan", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_atan(ctx, x, None) + extern_fns::call_atan(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arctan(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_sinh` builtin function. -pub fn call_numpy_sinh<'ctx>( +pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_sinh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_sinh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_sinh(ctx, x, None) + extern_fns::call_sinh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_sinh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_cosh` builtin function. -pub fn call_numpy_cosh<'ctx>( +pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_cosh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_cosh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_cosh(ctx, x, None) + extern_fns::call_cosh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_cosh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_tanh` builtin function. -pub fn call_numpy_tanh<'ctx>( +pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_tanh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_tanh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_tanh(ctx, x, None) + extern_fns::call_tanh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_tanh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } -/// Invokes the `np_asinh` builtin function. -pub fn call_numpy_asinh<'ctx>( +/// Invokes the `np_arcsinh` builtin function. +pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arcsinh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_asinh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_asinh(ctx, x, None) + extern_fns::call_asinh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arcsinh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } -/// Invokes the `np_acosh` builtin function. -pub fn call_numpy_acosh<'ctx>( +/// Invokes the `np_arccosh` builtin function. +pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arccosh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_acosh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_acosh(ctx, x, None) + extern_fns::call_acosh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arccosh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } -/// Invokes the `np_atanh` builtin function. -pub fn call_numpy_atanh<'ctx>( +/// Invokes the `np_arctanh` builtin function. +pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arctanh"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_atanh", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_atanh(ctx, x, None) + extern_fns::call_atanh(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_arctanh(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_expm1` builtin function. -pub fn call_numpy_expm1<'ctx>( +pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_expm1"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_expm1", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_expm1(ctx, x, None) + extern_fns::call_expm1(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_expm1(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_cbrt` builtin function. -pub fn call_numpy_cbrt<'ctx>( +pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_cbrt"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_cbrt", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_cbrt(ctx, x, None) + extern_fns::call_cbrt(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_numpy_cbrt(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `sp_spec_erf` builtin function. -pub fn call_scipy_special_erf<'ctx>( +pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - z: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + z: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_erf"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (z_ty, z) = z; - if !ctx.unifier.unioned(z_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_erf", &[z_ty]) - } + Ok(match z { + BasicValueEnum::FloatValue(z) => { + debug_assert!(ctx.unifier.unioned(z_ty, ctx.primitives.float)); - extern_fns::call_erf(ctx, z, None) + extern_fns::call_erf(ctx, z, None).into() + } + + BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(z, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_erf(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[z_ty]) + }) } /// Invokes the `sp_spec_erfc` builtin function. -pub fn call_scipy_special_erfc<'ctx>( +pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_erfc"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_erfc", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_erfc(ctx, x, None) + extern_fns::call_erfc(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_erfc(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `sp_spec_gamma` builtin function. -pub fn call_scipy_special_gamma<'ctx>( +pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - z: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + z: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_gamma"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (z_ty, z) = z; - if !ctx.unifier.unioned(z_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_gamma", &[z_ty]) - } + Ok(match z { + BasicValueEnum::FloatValue(z) => { + debug_assert!(ctx.unifier.unioned(z_ty, ctx.primitives.float)); - irrt::call_gamma(ctx, z) + irrt::call_gamma(ctx, z).into() + } + + BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(z, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_gamma(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[z_ty]) + }) } /// Invokes the `sp_spec_gammaln` builtin function. -pub fn call_scipy_special_gammaln<'ctx>( +pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_gammaln"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_gammaln", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_gammaln(ctx, x) + irrt::call_gammaln(ctx, x).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_gammaln(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `sp_spec_j0` builtin function. -pub fn call_scipy_special_j0<'ctx>( +pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_j0"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_j0", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - irrt::call_j0(ctx, x) + irrt::call_j0(ctx, x).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_j0(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `sp_spec_j1` builtin function. -pub fn call_scipy_special_j1<'ctx>( +pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + x: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_spec_j1"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (x_ty, x) = x; - if !ctx.unifier.unioned(x_ty, ctx.primitives.float) { - unsupported_type(ctx, "sp_spec_j1", &[x_ty]) - } + Ok(match x { + BasicValueEnum::FloatValue(x) => { + debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); - extern_fns::call_j1(ctx, x, None) + extern_fns::call_j1(ctx, x, None).into() + } + + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + elem_ty, + None, + NDArrayValue::from_ptr_val(x, llvm_usize, None), + |generator, ctx, val| { + call_scipy_special_j1(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x_ty]) + }) } /// Invokes the `np_arctan2` builtin function. -pub fn call_numpy_arctan2<'ctx>( +pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arctan2"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_atan2", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - extern_fns::call_atan2(ctx, x1, x2, None) + extern_fns::call_atan2(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_copysign` builtin function. -pub fn call_numpy_copysign<'ctx>( +pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_copysign"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - debug_assert_eq!(x1.get_type(), x2.get_type()); - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_copysign", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None) + llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_fmax` builtin function. -pub fn call_numpy_fmax<'ctx>( +pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_fmax"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - debug_assert_eq!(x1.get_type(), x2.get_type()); - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_fmax", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None) + llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_fmin` builtin function. -pub fn call_numpy_fmin<'ctx>( +pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_fmin"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - debug_assert_eq!(x1.get_type(), x2.get_type()); - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_fmin", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None) + llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_ldexp` builtin function. -pub fn call_numpy_ldexp<'ctx>( +pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, IntValue<'ctx>), -) -> FloatValue<'ctx> { + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_ldexp"; + let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - if !ctx.unifier.unioned(x1_ty, ctx.primitives.float) { - unsupported_type(ctx, "fp_ldexp", &[x1_ty, x2_ty]) - } - if !ctx.unifier.unioned(x2_ty, ctx.primitives.int32) { - unsupported_type(ctx, "fp_ldexp", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); - extern_fns::call_ldexp(ctx, x1, x2, None) + extern_fns::call_ldexp(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else { + x1_ty + }; + + let x1_scalar_ty = dtype; + let x2_scalar_ty = if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_hypot` builtin function. -pub fn call_numpy_hypot<'ctx>( +pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_hypot"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_hypot", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - extern_fns::call_hypot(ctx, x1, x2, None) + extern_fns::call_hypot(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } /// Invokes the `np_nextafter` builtin function. -pub fn call_numpy_nextafter<'ctx>( +pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, FloatValue<'ctx>), - x2: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { - let float_t = ctx.primitives.float; + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_nextafter"; let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - if !ctx.unifier.unioned(x1_ty, float_t) || !ctx.unifier.unioned(x2_ty, float_t) { - unsupported_type(ctx, "np_nextafter", &[x1_ty, x2_ty]) - } + Ok(match (x1, x2) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); - extern_fns::call_nextafter(ctx, x1, x2, None) + extern_fns::call_nextafter(ctx, x1, x2, None).into() + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) } \ No newline at end of file diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 62c763bcb..f22c721e0 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -451,8 +451,6 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( } }; - debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type()); - value_fn(generator, ctx, (lhs_elem, rhs_elem)) }, )?; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 5e6d56c8d..e057db963 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -282,7 +282,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .. } = *primitives; - let ndarray_float = make_ndarray_ty(unifier, &primitives, Some(float), None); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); let ndarray_float_2d = { let value = match primitives.size_t { 64 => SymbolValue::U64(2u64), @@ -294,7 +294,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built loc: None, }); - make_ndarray_ty(unifier, &primitives, Some(float), Some(ndims)) + make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) }; let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = unifier.get_fresh_var_with_range( @@ -302,8 +302,40 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Some("N".into()), None, ); + let num_var_map: VarMap = vec![ + (num_ty.1, num_ty.0), + ].into_iter().collect(); + + let new_type_or_ndarray_ty = |unifier: &mut Unifier, primitives: &PrimitiveStore, scalar_ty: Type| { + let ndarray = make_ndarray_ty(unifier, primitives, Some(scalar_ty), None); + + unifier.get_fresh_var_with_range( + &[scalar_ty, ndarray], + Some("T".into()), + None, + ) + }; + + let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); + let float_or_ndarray_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let float_or_ndarray_var_map: VarMap = vec![ + (float_or_ndarray_ty.1, float_or_ndarray_ty.0), + ].into_iter().collect(); + + let num_or_ndarray_ty = unifier.get_fresh_var_with_range( + &[num_ty.0, ndarray_num_ty], + Some("T".into()), + None, + ); + let num_or_ndarray_var_map: VarMap = vec![ + (num_ty.1, num_ty.0), + (num_or_ndarray_ty.1, num_or_ndarray_ty.0), + ].into_iter().collect(); - let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let exception_fields = vec![ ("__name__".into(), int32, true), ("__file__".into(), string, true), @@ -568,9 +600,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "int32".into(), simple_name: "int32".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: int32, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -581,7 +613,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_int32(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_int32(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -590,9 +622,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "int64".into(), simple_name: "int64".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: int64, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -603,7 +635,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_int64(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_int64(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -612,9 +644,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "uint32".into(), simple_name: "uint32".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: uint32, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -625,7 +657,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_uint32(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_uint32(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -634,9 +666,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "uint64".into(), simple_name: "uint64".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: uint64, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -647,7 +679,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_uint64(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_uint64(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, @@ -656,9 +688,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "float".into(), simple_name: "float".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: float, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -669,14 +701,14 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_float(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, })), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_ndarray", ndarray_float, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a @@ -689,7 +721,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built ), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_empty", ndarray_float, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a @@ -702,7 +734,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built ), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_zeros", ndarray_float, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a @@ -715,7 +747,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built ), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_ones", ndarray_float, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a @@ -727,16 +759,16 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }), ), { - let tv = unifier.get_fresh_var(Some("T".into()), None).0; + let tv = unifier.get_fresh_var(Some("T".into()), None); create_fn_by_codegen( unifier, - &var_map, + &[(tv.1, tv.0)].into_iter().collect(), "np_full", ndarray, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // type variable - &[(list_int32, "shape"), (tv, "fill_value")], + &[(list_int32, "shape"), (tv.0, "fill_value")], Box::new(|ctx, obj, fun, args, generator| { gen_ndarray_full(ctx, &obj, fun, &args, generator) .map(|val| Some(val.as_basic_value_enum())) @@ -758,7 +790,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) }, ], ret: ndarray_float_2d, - vars: var_map.clone(), + vars: VarMap::default(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -774,7 +806,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built })), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_identity", ndarray_float_2d, &[(int32, "n")], @@ -783,53 +815,96 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .map(|val| Some(val.as_basic_value_enum())) }), ), + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "round", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "round64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, - &var_map, - "round", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "round64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, + &float_or_ndarray_var_map, "np_round", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_round(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) }), ), Arc::new(RwLock::new(TopLevelDef::Function { @@ -961,9 +1036,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "bool".into(), simple_name: "bool".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: primitives.bool, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -974,107 +1049,193 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_bool(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, })), + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "floor", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "floor64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, - &var_map, - "floor", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i32.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "floor64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i64.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, + &float_or_ndarray_var_map, "np_floor", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), arg.get_type().into()))) + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) }), ), + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "ceil", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &[ + (common_ndim.1, common_ndim.0), + (p0_ty.1, p0_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + "ceil64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, arg_ty)?; + + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, - &var_map, - "ceil", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i32.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "ceil64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); - - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); - - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i64.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, + &float_or_ndarray_var_map, "np_ceil", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), arg.get_type().into()))) + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) }), ), Arc::new(RwLock::new({ @@ -1103,7 +1264,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .into_iter() .collect(), })), - var_id: vec![arg_ty.1], + var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -1199,7 +1360,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, ], ret: num_ty.0, - vars: var_map.clone(), + vars: num_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1226,7 +1387,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, ], ret: num_ty.0, - vars: var_map.clone(), + vars: num_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1248,9 +1409,9 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built name: "abs".into(), simple_name: "abs".into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], - ret: num_ty.0, - vars: var_map.clone(), + args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], + ret: num_or_ndarray_ty.0, + vars: num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1261,623 +1422,740 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let n_ty = fun.0.args[0].ty; let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - Ok(Some(builtin_fns::call_abs(ctx, (n_ty, n_val)))) + Ok(Some(builtin_fns::call_abs(generator, ctx, (n_ty, n_val))?)) }, )))), loc: None, })), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_isnan", boolean, &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &VarMap::new(), "np_isinf", boolean, &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_sin", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_sin(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_sin(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_cos", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_cos(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_cos(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_exp", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_exp(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_exp(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_exp2", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_exp2(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_exp2(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_log", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_log10", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log10(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log10(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_log2", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_log2(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_log2(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_fabs", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_fabs(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_fabs(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_sqrt", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_sqrt(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_sqrt(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_rint", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_rint(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_rint(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_tan", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_tan(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_tan(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arcsin", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_arcsin(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arcsin(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arccos", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_arccos(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arccos(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arctan", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_arctan(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arctan(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_sinh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_sinh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_sinh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_cosh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_cosh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_cosh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_tanh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_tanh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_tanh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arcsinh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_asinh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arcsinh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arccosh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_acosh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arccosh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_arctanh", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_atanh(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_arctanh(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_expm1", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_expm1(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_expm1(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "np_cbrt", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_cbrt(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_numpy_cbrt(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_erf", - float, - &[(float, "z")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "z")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, z_ty)?; - Ok(Some(builtin_fns::call_scipy_special_erf(ctx, (z_ty, z_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_erf(generator, ctx, (z_ty, z_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_erfc", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, z_ty)?; - Ok(Some(builtin_fns::call_scipy_special_erfc(ctx, (z_ty, z_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_erfc(generator, ctx, (z_ty, z_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_gamma", - float, - &[(float, "z")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "z")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, z_ty)?; - Ok(Some(builtin_fns::call_scipy_special_gamma(ctx, (z_ty, z_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_gamma(generator, ctx, (z_ty, z_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_gammaln", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_scipy_special_gammaln(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_gammaln(generator, ctx, (x_ty, x_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_j0", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let z_ty = fun.0.args[0].ty; let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, z_ty)?; - Ok(Some(builtin_fns::call_scipy_special_j0(ctx, (z_ty, z_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_j0(generator, ctx, (z_ty, z_val))?)) }), ), create_fn_by_codegen( unifier, - &var_map, + &float_or_ndarray_var_map, "sp_spec_j1", - float, - &[(float, "x")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "x")], Box::new(|ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_scipy_special_j1(ctx, (x_ty, x_val)).into())) + Ok(Some(builtin_fns::call_scipy_special_j1(generator, ctx, (x_ty, x_val))?)) }), ), // Not mapped: jv/yv, libm only supports integer orders. - create_fn_by_codegen( - unifier, - &var_map, - "np_arctan2", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_arctan2( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_copysign", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_arctan2".into(), + simple_name: "np_arctan2".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![ret_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_arctan2( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_copysign( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_fmax", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_copysign".into(), + simple_name: "np_copysign".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![ret_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_fmax( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_fmin", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + Ok(Some(builtin_fns::call_numpy_copysign( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_fmin( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_ldexp", - float, - &[(float, "x1"), (int32, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_int_value(); + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_fmax".into(), + simple_name: "np_fmax".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_ldexp( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_hypot", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + Ok(Some(builtin_fns::call_numpy_fmax( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_hypot( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "np_nextafter", - float, - &[(float, "x1"), (float, "x2")], - Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)? - .into_float_value(); - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)? - .into_float_value(); + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_fmin".into(), + simple_name: "np_fmin".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_nextafter( - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - ).into())) - }), - ), + Ok(Some(builtin_fns::call_numpy_fmin( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, int32); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); + + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_ldexp".into(), + simple_name: "np_ldexp".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_ldexp( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); + + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_hypot".into(), + simple_name: "np_hypot".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_hypot( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); + + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_nextafter".into(), + simple_name: "np_nextafter".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone() + .to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_nextafter( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + })))), + loc: None, + })) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "Some".into(), simple_name: "Some".into(), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index a9dc4ad1c..027ad4289 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index dc36b54b9..248aa920e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d6adcee10..c79adf850 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 55767a80c..968848599 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index f4f96f2d6..a217b0002 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index ce6cf6ee1..72fd21309 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -14,17 +14,7 @@ use crate::{ }, }; use itertools::{Itertools, izip}; -use nac3parser::ast::{ - self, - fold::{self, Fold}, - Arguments, - Comprehension, - ExprContext, - ExprKind, - Located, - Location, - StrRef -}; +use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef}; #[cfg(test)] mod test; @@ -860,57 +850,194 @@ impl<'a> Inferencer<'a> { }, })) } - // int64 is special because its argument can be a constant larger than int32 - if id == &"int64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.int64); - let v: Result = (*val).try_into(); - return if v.is_ok() { + + if [ + "int32", + "float", + "bool", + "np_isnan", + "np_isinf", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + let target_ty = if id == &"int32".into() { + self.primitives.int32 + } else if id == &"float".into() { + self.primitives.float + } else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() { + self.primitives.bool + } else { unreachable!() }; + + let arg0 = self.fold_expr(args.remove(0))?; + let arg0_ty = arg0.custom.unwrap(); + + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + + make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) + } else { + target_ty + }; + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "n".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + + if [ + "np_arctan2", + "np_copysign", + "np_fmax", + "np_fmin", + "np_ldexp", + "np_hypot", + "np_nextafter", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 { + let target_ty = self.primitives.float; + + let arg0 = self.fold_expr(args.remove(0))?; + let arg0_ty = arg0.custom.unwrap(); + let arg1 = self.fold_expr(args.remove(0))?; + let arg1_ty = arg1.custom.unwrap(); + + let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + unpack_ndarray_var_tys(self.unifier, arg0_ty).0 + } else { + arg0_ty + }; + + let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + unpack_ndarray_var_tys(self.unifier, arg1_ty).0 + } else { + arg1_ty + }; + let expected_arg1_dtype = if id == &"np_ldexp".into() { + self.primitives.int32 + } else { + arg0_dtype + }; + if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) { + return report_error( + format!( + "Expected {} for second argument of {id}, got {}", + self.unifier.stringify(expected_arg1_dtype), + self.unifier.stringify(arg1_dtype), + ).as_str(), + arg0.location, + ) + } + + let ret = if [ + &arg0_ty, + &arg1_ty, + ].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) { + // typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts + // (float, int32), so convert it to align with the dtype of the first arg + let arg1_ty = if id == &"np_ldexp".into() { + if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); + + make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) + } else { + target_ty + } + } else { + arg1_ty + }; + + match typeof_ndarray_broadcast(self.unifier, self.primitives, arg0_ty, arg1_ty) { + Ok(broadcasted_ty) => broadcasted_ty, + Err(err) => return report_error(err.as_str(), location), + } + } else { + target_ty + }; + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "x1".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + FuncArg { + name: "x2".into(), + ty: arg1.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0, arg1], + keywords: vec![], + }, + })) + } + + // int64, uint32 and uint64 are special because their argument can be a constant outside the + // range of int32s + if [ + "int64", + "uint32", + "uint64", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + let target_ty = if id == &"int64".into() { + self.primitives.int64 + } else if id == &"uint32".into() { + self.primitives.uint32 + } else if id == &"uint64".into() { + self.primitives.uint64 + } else { unreachable!() }; + + // Handle constants first to ensure that their types are not defaulted to int32, which + // causes an "Integer out of bound" error + if let ExprKind::Constant { + value: ast::Constant::Int(val), + kind + } = &args[0].node { + let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) { + i64::try_from(*val).is_ok() + } else if self.unifier.unioned(target_ty, self.primitives.uint32) { + u32::try_from(*val).is_ok() + } else if self.unifier.unioned(target_ty, self.primitives.uint64) { + u64::try_from(*val).is_ok() + } else { unreachable!() }; + + return if conv_is_ok { Ok(Some(Located { location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - })) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == &"uint32".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint32); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Some(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - })) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == &"uint64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint64); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Some(Located { - location: args[0].location, - custom, + custom: Some(target_ty), node: ExprKind::Constant { value: ast::Constant::Int(*val), kind: kind.clone(), @@ -920,6 +1047,43 @@ impl<'a> Inferencer<'a> { report_error("Integer out of bound", args[0].location) } } + + let arg0 = self.fold_expr(args.remove(0))?; + let arg0_ty = arg0.custom.unwrap(); + + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + + make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) + } else { + target_ty + }; + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "n".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) } // 1-argument ndarray n-dimensional creation functions diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 03deff455..f38e70bdb 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -58,11 +58,38 @@ class _NDArrayDummy(Generic[T, N]): # https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]] -def round_away_zero(x): - if x >= 0.0: - return math.floor(x + 0.5) +def _bool(x): + if isinstance(x, np.ndarray): + return np.bool_(x) else: - return math.ceil(x - 0.5) + return bool(x) + +def _float(x): + if isinstance(x, np.ndarray): + return np.float_(x) + else: + return float(x) + +def round_away_zero(x): + if isinstance(x, np.ndarray): + return np.vectorize(round_away_zero)(x) + else: + if x >= 0.0: + return math.floor(x + 0.5) + else: + return math.ceil(x - 0.5) + +def _floor(x): + if isinstance(x, np.ndarray): + return np.vectorize(_floor)(x) + else: + return math.floor(x) + +def _ceil(x): + if isinstance(x, np.ndarray): + return np.vectorize(_ceil)(x) + else: + return math.ceil(x) def patch(module): def dbl_nan(): @@ -112,6 +139,8 @@ def patch(module): module.int64 = int64 module.uint32 = uint32 module.uint64 = uint64 + module.bool = _bool + module.float = _float module.TypeVar = TypeVar module.ConstGeneric = ConstGeneric module.Generic = Generic @@ -125,11 +154,11 @@ def patch(module): module.round = round_away_zero module.round64 = round_away_zero module.np_round = np.round - module.floor = math.floor - module.floor64 = math.floor + module.floor = _floor + module.floor64 = _floor module.np_floor = np.floor - module.ceil = math.ceil - module.ceil64 = math.ceil + module.ceil = _ceil + module.ceil64 = _ceil module.np_ceil = np.ceil # NumPy ndarray functions diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index eafe39d81..15afd5d21 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1,3 +1,11 @@ +@extern +def dbl_nan() -> float: + ... + +@extern +def dbl_inf() -> float: + ... + @extern def output_bool(x: bool): ... @@ -6,6 +14,18 @@ def output_bool(x: bool): def output_int32(x: int32): ... +@extern +def output_int64(x: int64): + ... + +@extern +def output_uint32(x: uint32): + ... + +@extern +def output_uint64(x: uint64): + ... + @extern def output_float64(x: float): ... @@ -24,6 +44,21 @@ def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]): for c in range(len(n[r])): output_int32(n[r][c]) +def output_ndarray_int64_2(n: ndarray[int64, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_int64(n[r][c]) + +def output_ndarray_uint32_2(n: ndarray[uint32, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_uint32(n[r][c]) + +def output_ndarray_uint64_2(n: ndarray[uint64, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_uint64(n[r][c]) + def output_ndarray_float_1(n: ndarray[float, Literal[1]]): for i in range(len(n)): output_float64(n[i]) @@ -649,6 +684,586 @@ def test_ndarray_ge_broadcast_rhs_scalar(): output_ndarray_float_2(x) output_ndarray_bool_2(y) +def test_ndarray_int32(): + x = np_identity(2) + y = int32(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(y) + +def test_ndarray_int64(): + x = np_identity(2) + y = int64(x) + + output_ndarray_float_2(x) + output_ndarray_int64_2(y) + +def test_ndarray_uint32(): + x = np_identity(2) + y = uint32(x) + + output_ndarray_float_2(x) + output_ndarray_uint32_2(y) + +def test_ndarray_uint64(): + x = np_identity(2) + y = uint64(x) + + output_ndarray_float_2(x) + output_ndarray_uint64_2(y) + +def test_ndarray_float(): + x = np_full([2, 2], 1) + y = float(x) + + output_ndarray_int32_2(x) + output_ndarray_float_2(y) + +def test_ndarray_bool(): + x = np_identity(2) + y = bool(x) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_round(): + x = np_identity(2) + xf32 = round(x) + xf64 = round64(x) + xff = np_round(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + +def test_ndarray_floor(): + x = np_identity(2) + xf32 = floor(x) + xf64 = floor64(x) + xff = np_floor(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + +def test_ndarray_ceil(): + x = np_identity(2) + xf32 = ceil(x) + xf64 = ceil64(x) + xff = np_ceil(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + +def test_ndarray_abs(): + x = np_identity(2) + y = abs(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_isnan(): + x = np_identity(2) + x_isnan = np_isnan(x) + y = np_full([2, 2], dbl_nan()) + y_isnan = np_isnan(y) + + output_ndarray_float_2(x) + output_ndarray_bool_2(x_isnan) + output_ndarray_float_2(y) + output_ndarray_bool_2(y_isnan) + +def test_ndarray_isinf(): + x = np_identity(2) + x_isinf = np_isinf(x) + y = np_full([2, 2], dbl_inf()) + y_isinf = np_isinf(y) + + output_ndarray_float_2(x) + output_ndarray_bool_2(x_isinf) + output_ndarray_float_2(y) + output_ndarray_bool_2(y_isinf) + +def test_ndarray_sin(): + x = np_identity(2) + y = np_sin(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_cos(): + x = np_identity(2) + y = np_cos(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_exp(): + x = np_identity(2) + y = np_exp(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_exp2(): + x = np_identity(2) + y = np_exp2(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log(): + x = np_identity(2) + y = np_log(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log10(): + x = np_identity(2) + y = np_log10(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_log2(): + x = np_identity(2) + y = np_log2(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_fabs(): + x = -np_identity(2) + y = np_fabs(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sqrt(): + x = np_identity(2) + y = np_sqrt(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_rint(): + x = np_identity(2) + y = np_rint(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_tan(): + x = np_identity(2) + y = np_tan(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arcsin(): + x = np_identity(2) + y = np_arcsin(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arccos(): + x = np_identity(2) + y = np_arccos(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arctan(): + x = np_identity(2) + y = np_arctan(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sinh(): + x = np_identity(2) + y = np_sinh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_cosh(): + x = np_identity(2) + y = np_cosh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_tanh(): + x = np_identity(2) + y = np_tanh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arcsinh(): + x = np_identity(2) + y = np_arcsinh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arccosh(): + x = np_identity(2) + y = np_arccosh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arctanh(): + x = np_identity(2) + y = np_arctanh(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_expm1(): + x = np_identity(2) + y = np_expm1(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_cbrt(): + x = np_identity(2) + y = np_cbrt(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_erf(): + x = np_identity(2) + y = sp_spec_erf(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_erfc(): + x = np_identity(2) + y = sp_spec_erfc(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_gamma(): + x = np_identity(2) + y = sp_spec_gamma(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_gammaln(): + x = np_identity(2) + y = sp_spec_gammaln(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_j0(): + x = np_identity(2) + y = sp_spec_j0(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_j1(): + x = np_identity(2) + y = sp_spec_j1(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_arctan2(): + x = np_identity(2) + zeros = np_zeros([2, 2]) + ones = np_ones([2, 2]) + atan2_x_zeros = np_arctan2(x, zeros) + atan2_x_ones = np_arctan2(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(zeros) + output_ndarray_float_2(ones) + output_ndarray_float_2(atan2_x_zeros) + output_ndarray_float_2(atan2_x_ones) + +def test_ndarray_arctan2_broadcast(): + x = np_identity(2) + atan2_x_zeros = np_arctan2(x, np_zeros([2])) + atan2_x_ones = np_arctan2(x, np_ones([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(atan2_x_zeros) + output_ndarray_float_2(atan2_x_ones) + +def test_ndarray_arctan2_broadcast_lhs_scalar(): + x = np_identity(2) + atan2_x_zeros = np_arctan2(0.0, x) + atan2_x_ones = np_arctan2(1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(atan2_x_zeros) + output_ndarray_float_2(atan2_x_ones) + +def test_ndarray_arctan2_broadcast_rhs_scalar(): + x = np_identity(2) + atan2_x_zeros = np_arctan2(x, 0.0) + atan2_x_ones = np_arctan2(x, 1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(atan2_x_zeros) + output_ndarray_float_2(atan2_x_ones) + +def test_ndarray_copysign(): + x = np_identity(2) + ones = np_ones([2, 2]) + negones = np_full([2, 2], -1.0) + copysign_x_ones = np_copysign(x, ones) + copysign_x_negones = np_copysign(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(ones) + output_ndarray_float_2(negones) + output_ndarray_float_2(copysign_x_ones) + output_ndarray_float_2(copysign_x_negones) + +def test_ndarray_copysign_broadcast(): + x = np_identity(2) + copysign_x_ones = np_copysign(x, np_ones([2])) + copysign_x_negones = np_copysign(x, np_full([2], -1.0)) + + output_ndarray_float_2(x) + output_ndarray_float_2(copysign_x_ones) + output_ndarray_float_2(copysign_x_negones) + +def test_ndarray_copysign_broadcast_lhs_scalar(): + x = np_identity(2) + copysign_x_ones = np_copysign(1.0, x) + copysign_x_negones = np_copysign(-1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(copysign_x_ones) + output_ndarray_float_2(copysign_x_negones) + +def test_ndarray_copysign_broadcast_rhs_scalar(): + x = np_identity(2) + copysign_x_ones = np_copysign(x, 1.0) + copysign_x_negones = np_copysign(x, -1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(copysign_x_ones) + output_ndarray_float_2(copysign_x_negones) + +def test_ndarray_fmax(): + x = np_identity(2) + ones = np_ones([2, 2]) + negones = np_full([2, 2], -1.0) + fmax_x_ones = np_fmax(x, ones) + fmax_x_negones = np_fmax(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(ones) + output_ndarray_float_2(negones) + output_ndarray_float_2(fmax_x_ones) + output_ndarray_float_2(fmax_x_negones) + +def test_ndarray_fmax_broadcast(): + x = np_identity(2) + fmax_x_ones = np_fmax(x, np_ones([2])) + fmax_x_negones = np_fmax(x, np_full([2], -1.0)) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmax_x_ones) + output_ndarray_float_2(fmax_x_negones) + +def test_ndarray_fmax_broadcast_lhs_scalar(): + x = np_identity(2) + fmax_x_ones = np_fmax(1.0, x) + fmax_x_negones = np_fmax(-1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmax_x_ones) + output_ndarray_float_2(fmax_x_negones) + +def test_ndarray_fmax_broadcast_rhs_scalar(): + x = np_identity(2) + fmax_x_ones = np_fmax(x, 1.0) + fmax_x_negones = np_fmax(x, -1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmax_x_ones) + output_ndarray_float_2(fmax_x_negones) + +def test_ndarray_fmin(): + x = np_identity(2) + ones = np_ones([2, 2]) + negones = np_full([2, 2], -1.0) + fmin_x_ones = np_fmin(x, ones) + fmin_x_negones = np_fmin(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(ones) + output_ndarray_float_2(negones) + output_ndarray_float_2(fmin_x_ones) + output_ndarray_float_2(fmin_x_negones) + +def test_ndarray_fmin_broadcast(): + x = np_identity(2) + fmin_x_ones = np_fmin(x, np_ones([2])) + fmin_x_negones = np_fmin(x, np_full([2], -1.0)) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmin_x_ones) + output_ndarray_float_2(fmin_x_negones) + +def test_ndarray_fmin_broadcast_lhs_scalar(): + x = np_identity(2) + fmin_x_ones = np_fmin(1.0, x) + fmin_x_negones = np_fmin(-1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmin_x_ones) + output_ndarray_float_2(fmin_x_negones) + +def test_ndarray_fmin_broadcast_rhs_scalar(): + x = np_identity(2) + fmin_x_ones = np_fmin(x, 1.0) + fmin_x_negones = np_fmin(x, -1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(fmin_x_ones) + output_ndarray_float_2(fmin_x_negones) + +def test_ndarray_ldexp(): + x = np_identity(2) + zeros = np_full([2, 2], 0) + ones = np_full([2, 2], 1) + ldexp_x_zeros = np_ldexp(x, zeros) + ldexp_x_ones = np_ldexp(x, ones) + + output_ndarray_float_2(x) + output_ndarray_int32_2(zeros) + output_ndarray_int32_2(ones) + output_ndarray_float_2(ldexp_x_zeros) + output_ndarray_float_2(ldexp_x_ones) + +def test_ndarray_ldexp_broadcast(): + x = np_identity(2) + ldexp_x_zeros = np_ldexp(x, np_full([2], 0)) + ldexp_x_ones = np_ldexp(x, np_full([2], 1)) + + output_ndarray_float_2(x) + output_ndarray_float_2(ldexp_x_zeros) + output_ndarray_float_2(ldexp_x_ones) + +def test_ndarray_ldexp_broadcast_lhs_scalar(): + x = int32(np_identity(2)) + ldexp_x_zeros = np_ldexp(0.0, x) + ldexp_x_ones = np_ldexp(1.0, x) + + output_ndarray_int32_2(x) + output_ndarray_float_2(ldexp_x_zeros) + output_ndarray_float_2(ldexp_x_ones) + +def test_ndarray_ldexp_broadcast_rhs_scalar(): + x = np_identity(2) + ldexp_x_zeros = np_ldexp(x, 0) + ldexp_x_ones = np_ldexp(x, 1) + + output_ndarray_float_2(x) + output_ndarray_float_2(ldexp_x_zeros) + output_ndarray_float_2(ldexp_x_ones) + +def test_ndarray_hypot(): + x = np_identity(2) + zeros = np_zeros([2, 2]) + ones = np_ones([2, 2]) + hypot_x_zeros = np_hypot(x, zeros) + hypot_x_ones = np_hypot(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(zeros) + output_ndarray_float_2(ones) + output_ndarray_float_2(hypot_x_zeros) + output_ndarray_float_2(hypot_x_ones) + +def test_ndarray_hypot_broadcast(): + x = np_identity(2) + hypot_x_zeros = np_hypot(x, np_zeros([2])) + hypot_x_ones = np_hypot(x, np_ones([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(hypot_x_zeros) + output_ndarray_float_2(hypot_x_ones) + +def test_ndarray_hypot_broadcast_lhs_scalar(): + x = np_identity(2) + hypot_x_zeros = np_hypot(0.0, x) + hypot_x_ones = np_hypot(1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(hypot_x_zeros) + output_ndarray_float_2(hypot_x_ones) + +def test_ndarray_hypot_broadcast_rhs_scalar(): + x = np_identity(2) + hypot_x_zeros = np_hypot(x, 0.0) + hypot_x_ones = np_hypot(x, 1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(hypot_x_zeros) + output_ndarray_float_2(hypot_x_ones) + +def test_ndarray_nextafter(): + x = np_identity(2) + zeros = np_zeros([2, 2]) + ones = np_ones([2, 2]) + nextafter_x_zeros = np_nextafter(x, zeros) + nextafter_x_ones = np_nextafter(x, ones) + + output_ndarray_float_2(x) + output_ndarray_float_2(zeros) + output_ndarray_float_2(ones) + output_ndarray_float_2(nextafter_x_zeros) + output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_nextafter_broadcast(): + x = np_identity(2) + nextafter_x_zeros = np_nextafter(x, np_zeros([2])) + nextafter_x_ones = np_nextafter(x, np_ones([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(nextafter_x_zeros) + output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_nextafter_broadcast_lhs_scalar(): + x = np_identity(2) + nextafter_x_zeros = np_nextafter(0.0, x) + nextafter_x_ones = np_nextafter(1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(nextafter_x_zeros) + output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_nextafter_broadcast_rhs_scalar(): + x = np_identity(2) + nextafter_x_zeros = np_nextafter(x, 0.0) + nextafter_x_ones = np_nextafter(x, 1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(nextafter_x_zeros) + output_ndarray_float_2(nextafter_x_ones) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -739,4 +1354,76 @@ def run() -> int32: test_ndarray_ge_broadcast_lhs_scalar() test_ndarray_ge_broadcast_rhs_scalar() + test_ndarray_int32() + test_ndarray_int64() + test_ndarray_uint32() + test_ndarray_uint64() + test_ndarray_float() + test_ndarray_bool() + + test_ndarray_round() + test_ndarray_floor() + test_ndarray_abs() + test_ndarray_isnan() + test_ndarray_isinf() + + test_ndarray_sin() + test_ndarray_cos() + test_ndarray_exp() + test_ndarray_exp2() + test_ndarray_log() + test_ndarray_log10() + test_ndarray_log2() + test_ndarray_fabs() + test_ndarray_sqrt() + test_ndarray_rint() + test_ndarray_tan() + test_ndarray_arcsin() + test_ndarray_arccos() + test_ndarray_arctan() + test_ndarray_sinh() + test_ndarray_cosh() + test_ndarray_tanh() + test_ndarray_arcsinh() + test_ndarray_arccosh() + test_ndarray_arctanh() + test_ndarray_expm1() + test_ndarray_cbrt() + + test_ndarray_erf() + test_ndarray_erfc() + test_ndarray_gamma() + test_ndarray_gammaln() + test_ndarray_j0() + test_ndarray_j1() + + test_ndarray_arctan2() + test_ndarray_arctan2_broadcast() + test_ndarray_arctan2_broadcast_lhs_scalar() + test_ndarray_arctan2_broadcast_rhs_scalar() + test_ndarray_copysign() + test_ndarray_copysign_broadcast() + test_ndarray_copysign_broadcast_lhs_scalar() + test_ndarray_copysign_broadcast_rhs_scalar() + test_ndarray_fmax() + test_ndarray_fmax_broadcast() + test_ndarray_fmax_broadcast_lhs_scalar() + test_ndarray_fmax_broadcast_rhs_scalar() + test_ndarray_fmin() + test_ndarray_fmin_broadcast() + test_ndarray_fmin_broadcast_lhs_scalar() + test_ndarray_fmin_broadcast_rhs_scalar() + test_ndarray_ldexp() + test_ndarray_ldexp_broadcast() + test_ndarray_ldexp_broadcast_lhs_scalar() + test_ndarray_ldexp_broadcast_rhs_scalar() + test_ndarray_hypot() + test_ndarray_hypot_broadcast() + test_ndarray_hypot_broadcast_lhs_scalar() + test_ndarray_hypot_broadcast_rhs_scalar() + test_ndarray_nextafter() + test_ndarray_nextafter_broadcast() + test_ndarray_nextafter_broadcast_lhs_scalar() + test_ndarray_nextafter_broadcast_rhs_scalar() + return 0