2024-04-25 15:47:16 +08:00
use inkwell ::types ::BasicTypeEnum ;
use inkwell ::values ::BasicValueEnum ;
2024-06-12 14:45:03 +08:00
use inkwell ::{ FloatPredicate , IntPredicate , OptimizationLevel } ;
2024-04-24 17:40:25 +08:00
use itertools ::Itertools ;
2024-06-06 12:16:09 +08:00
use crate ::codegen ::classes ::{ NDArrayValue , ProxyValue , UntypedArrayLikeAccessor } ;
2024-04-25 15:47:16 +08:00
use crate ::codegen ::numpy ::ndarray_elementwise_unaryop_impl ;
2024-05-08 17:42:19 +08:00
use crate ::codegen ::stmt ::gen_for_callback_incrementing ;
2024-06-12 14:45:03 +08:00
use crate ::codegen ::{ extern_fns , irrt , llvm_intrinsics , numpy , CodeGenContext , CodeGenerator } ;
2024-06-12 15:01:01 +08:00
use crate ::toplevel ::helper ::PrimDef ;
2024-04-25 15:47:16 +08:00
use crate ::toplevel ::numpy ::unpack_ndarray_var_tys ;
2024-04-24 17:40:25 +08:00
use crate ::typecheck ::typedef ::Type ;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
/// The generated message will contain the function name and the name of the unsupported type.
2024-06-12 14:45:03 +08:00
fn unsupported_type ( ctx : & CodeGenContext < '_ , '_ > , fn_name : & str , tys : & [ Type ] ) -> ! {
2024-04-24 17:40:25 +08:00
unreachable! (
" {fn_name}() not supported for '{}' " ,
tys . iter ( ) . map ( | ty | format! ( " ' {} ' " , ctx . unifier . stringify ( * ty ) ) ) . join ( " , " ) ,
)
}
/// Invokes the `int32` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_int32 < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
let llvm_i32 = ctx . ctx . i32_type ( ) ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . bool ) ) ;
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_z_extend ( n , llvm_i32 , " zext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 32 = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . int32 , ctx . primitives . uint32 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
n . into ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 64 = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . int64 , ctx . primitives . uint64 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_truncate ( n , llvm_i32 , " trunc " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-06-12 14:45:03 +08:00
let to_int64 =
ctx . builder . build_float_to_signed_int ( n , ctx . ctx . i64_type ( ) , " " ) . unwrap ( ) ;
ctx . builder . build_int_truncate ( to_int64 , llvm_i32 , " conv " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . int32 ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_int32 ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , " int32 " , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `int64` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_int64 < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
let llvm_i64 = ctx . ctx . i64_type ( ) ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 | 32 ) = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . bool , ctx . primitives . int32 , ctx . primitives . uint32 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
if ctx . unifier . unioned ( n_ty , ctx . primitives . int32 ) {
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_s_extend ( n , llvm_i64 , " sext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
} else {
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_z_extend ( n , llvm_i64 , " zext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 64 = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . int64 , ctx . primitives . uint64 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
n . into ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
ctx . builder
2024-04-25 15:47:16 +08:00
. build_float_to_signed_int ( n , ctx . ctx . i64_type ( ) , " fptosi " )
. map ( Into ::into )
2024-04-24 17:40:25 +08:00
. unwrap ( )
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . int64 ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_int64 ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , " int64 " , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `uint32` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_uint32 < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
let llvm_i32 = ctx . ctx . i32_type ( ) ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . bool ) ) ;
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_z_extend ( n , llvm_i32 , " zext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 32 = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . int32 , ctx . primitives . uint32 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
n . into ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 64 = > {
2024-04-24 17:40:25 +08:00
debug_assert! (
ctx . unifier . unioned ( n_ty , ctx . primitives . int64 )
| | ctx . unifier . unioned ( n_ty , ctx . primitives . uint64 )
) ;
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_truncate ( n , llvm_i32 , " trunc " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-06-12 14:45:03 +08:00
let n_gez = ctx
. builder
2024-04-25 15:47:16 +08:00
. build_float_compare ( FloatPredicate ::OGE , n , n . get_type ( ) . const_zero ( ) , " " )
2024-04-24 17:40:25 +08:00
. unwrap ( ) ;
2024-06-12 14:45:03 +08:00
let to_int32 = ctx . builder . build_float_to_signed_int ( n , llvm_i32 , " " ) . unwrap ( ) ;
let to_uint64 =
ctx . builder . build_float_to_unsigned_int ( n , ctx . ctx . i64_type ( ) , " " ) . unwrap ( ) ;
2024-04-24 17:40:25 +08:00
ctx . builder
. build_select (
2024-04-25 15:47:16 +08:00
n_gez ,
2024-04-24 17:40:25 +08:00
ctx . builder . build_int_truncate ( to_uint64 , llvm_i32 , " " ) . unwrap ( ) ,
to_int32 ,
" conv " ,
)
. unwrap ( )
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . uint32 ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_uint32 ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , " uint32 " , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `uint64` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_uint64 < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
let llvm_i64 = ctx . ctx . i64_type ( ) ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 | 32 ) = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . bool , ctx . primitives . int32 , ctx . primitives . uint32 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
if ctx . unifier . unioned ( n_ty , ctx . primitives . int32 ) {
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_s_extend ( n , llvm_i64 , " sext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
} else {
2024-06-12 14:45:03 +08:00
ctx . builder . build_int_z_extend ( n , llvm_i64 , " zext " ) . map ( Into ::into ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) if n . get_type ( ) . get_bit_width ( ) = = 64 = > {
2024-06-12 14:45:03 +08:00
debug_assert! ( [ ctx . primitives . int64 , ctx . primitives . uint64 , ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
n . into ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-06-12 14:45:03 +08:00
let val_gez = ctx
. builder
2024-04-25 15:47:16 +08:00
. build_float_compare ( FloatPredicate ::OGE , n , n . get_type ( ) . const_zero ( ) , " " )
2024-04-24 17:40:25 +08:00
. unwrap ( ) ;
2024-06-12 14:45:03 +08:00
let to_int64 = ctx . builder . build_float_to_signed_int ( n , llvm_i64 , " " ) . unwrap ( ) ;
let to_uint64 = ctx . builder . build_float_to_unsigned_int ( n , llvm_i64 , " " ) . unwrap ( ) ;
2024-04-24 17:40:25 +08:00
2024-06-12 14:45:03 +08:00
ctx . builder . build_select ( val_gez , to_uint64 , to_int64 , " conv " ) . unwrap ( )
2024-04-24 17:40:25 +08:00
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . uint64 ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_uint64 ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , " uint64 " , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `float` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_float < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
let llvm_f64 = ctx . ctx . f64_type ( ) ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 | 32 | 64 ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
if [ ctx . primitives . bool , ctx . primitives . int32 , ctx . primitives . int64 ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) )
{
2024-04-24 17:40:25 +08:00
ctx . builder
2024-04-25 15:47:16 +08:00
. build_signed_int_to_float ( n , llvm_f64 , " sitofp " )
. map ( Into ::into )
2024-04-24 17:40:25 +08:00
. unwrap ( )
} else {
ctx . builder
2024-04-25 15:47:16 +08:00
. build_unsigned_int_to_float ( n , llvm_f64 , " uitofp " )
. map ( Into ::into )
. unwrap ( )
2024-04-24 17:40:25 +08:00
}
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-04-25 15:47:16 +08:00
n . into ( )
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . float ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_float ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-24 17:40:25 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , " float " , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `round` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_round < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
n : ( Type , BasicValueEnum < ' ctx > ) ,
ret_elem_ty : Type ,
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
const FN_NAME : & str = " round " ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
let llvm_ret_elem_ty = ctx . get_llvm_abi_type ( generator , ret_elem_ty ) . into_int_type ( ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::FloatValue ( n ) = > {
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
let val = llvm_intrinsics ::call_float_round ( ctx , n , None ) ;
ctx . builder
. build_float_to_signed_int ( val , llvm_ret_elem_ty , FN_NAME )
. map ( Into ::into )
2024-06-06 12:16:09 +08:00
. unwrap ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ret_elem_ty ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_round ( generator , ctx , ( elem_ty , val ) , ret_elem_ty ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_round` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_round < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
n : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_round " ;
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::FloatValue ( n ) = > {
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-04-24 17:40:25 +08:00
2024-07-05 13:35:22 +08:00
llvm_intrinsics ::call_float_rint ( ctx , n , None ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . float ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_numpy_round ( generator , ctx , ( elem_ty , val ) ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `bool` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_bool < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
const FN_NAME : & str = " bool " ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::IntValue ( n ) if matches! ( n . get_type ( ) . get_bit_width ( ) , 1 | 8 ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . bool ) ) ;
2024-04-25 15:47:16 +08:00
n . into ( )
2024-04-24 17:40:25 +08:00
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::IntValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( [
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( n_ty , * ty ) ) ) ;
2024-04-24 17:40:25 +08:00
ctx . builder
2024-04-25 15:47:16 +08:00
. build_int_compare ( IntPredicate ::NE , n , n . get_type ( ) . const_zero ( ) , FN_NAME )
. map ( Into ::into )
2024-04-24 17:40:25 +08:00
. unwrap ( )
}
2024-04-25 15:47:16 +08:00
BasicValueEnum ::FloatValue ( n ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
ctx . builder
2024-04-25 15:47:16 +08:00
. build_float_compare ( FloatPredicate ::UNE , n , n . get_type ( ) . const_zero ( ) , FN_NAME )
. map ( Into ::into )
2024-04-24 17:40:25 +08:00
. unwrap ( )
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ctx . primitives . bool ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
| generator , ctx , val | {
2024-06-12 14:45:03 +08:00
let elem = call_bool ( generator , ctx , ( elem_ty , val ) ) ? ;
2024-04-25 15:47:16 +08:00
Ok ( generator . bool_to_i8 ( ctx , elem . into_int_value ( ) ) . into ( ) )
} ,
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `floor` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_floor < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
n : ( Type , BasicValueEnum < ' ctx > ) ,
ret_elem_ty : Type ,
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
const FN_NAME : & str = " floor " ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
let llvm_ret_elem_ty = ctx . get_llvm_abi_type ( generator , ret_elem_ty ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::FloatValue ( n ) = > {
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
let val = llvm_intrinsics ::call_float_floor ( ctx , n , None ) ;
if let BasicTypeEnum ::IntType ( llvm_ret_elem_ty ) = llvm_ret_elem_ty {
ctx . builder
. build_float_to_signed_int ( val , llvm_ret_elem_ty , FN_NAME )
. map ( Into ::into )
. unwrap ( )
} else {
val . into ( )
}
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ret_elem_ty ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_floor ( generator , ctx , ( elem_ty , val ) , ret_elem_ty ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-24 17:40:25 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `ceil` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_ceil < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
n : ( Type , BasicValueEnum < ' ctx > ) ,
ret_elem_ty : Type ,
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-04-24 17:40:25 +08:00
const FN_NAME : & str = " ceil " ;
2024-04-25 15:47:16 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
2024-04-24 17:40:25 +08:00
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
let llvm_ret_elem_ty = ctx . get_llvm_abi_type ( generator , ret_elem_ty ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
Ok ( match n {
BasicValueEnum ::FloatValue ( n ) = > {
debug_assert! ( ctx . unifier . unioned ( n_ty , ctx . primitives . float ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
let val = llvm_intrinsics ::call_float_ceil ( ctx , n , None ) ;
if let BasicTypeEnum ::IntType ( llvm_ret_elem_ty ) = llvm_ret_elem_ty {
ctx . builder
. build_float_to_signed_int ( val , llvm_ret_elem_ty , FN_NAME )
. map ( Into ::into )
. unwrap ( )
} else {
val . into ( )
}
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if n_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-04-25 15:47:16 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , n_ty ) ;
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
ret_elem_ty ,
None ,
NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ,
2024-06-12 14:45:03 +08:00
| generator , ctx , val | call_floor ( generator , ctx , ( elem_ty , val ) , ret_elem_ty ) ,
2024-04-25 15:47:16 +08:00
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-24 17:40:25 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ n_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `min` builtin function.
pub fn call_min < ' ctx > (
ctx : & mut CodeGenContext < ' ctx , '_ > ,
m : ( Type , BasicValueEnum < ' ctx > ) ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
) -> BasicValueEnum < ' ctx > {
const FN_NAME : & str = " min " ;
let ( m_ty , m ) = m ;
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
let common_ty = if ctx . unifier . unioned ( m_ty , n_ty ) {
m_ty
} else {
2024-04-24 17:40:25 +08:00
unsupported_type ( ctx , FN_NAME , & [ m_ty , n_ty ] )
2024-04-25 15:47:16 +08:00
} ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
match ( m , n ) {
( BasicValueEnum ::IntValue ( m ) , BasicValueEnum ::IntValue ( n ) ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty , * ty ) ) ) ;
if [ ctx . primitives . int32 , ctx . primitives . int64 ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty , * ty ) )
{
2024-04-24 17:40:25 +08:00
llvm_intrinsics ::call_int_smin ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
} else {
llvm_intrinsics ::call_int_umin ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
}
}
2024-04-25 15:47:16 +08:00
( BasicValueEnum ::FloatValue ( m ) , BasicValueEnum ::FloatValue ( n ) ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( common_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_minnum ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ m_ty , n_ty ] ) ,
2024-04-24 17:40:25 +08:00
}
}
2024-05-08 18:29:11 +08:00
/// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
ctx : & mut CodeGenContext < ' ctx , '_ > ,
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_minimum " ;
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-06-12 14:45:03 +08:00
let common_ty = if ctx . unifier . unioned ( x1_ty , x2_ty ) { Some ( x1_ty ) } else { None } ;
2024-05-08 18:29:11 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::IntValue ( x1 ) , BasicValueEnum ::IntValue ( x2 ) ) = > {
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
ctx . primitives . float ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty . unwrap ( ) , * ty ) ) ) ;
2024-05-08 18:29:11 +08:00
call_min ( ctx , ( x1_ty , x1 . into ( ) ) , ( x2_ty , x2 . into ( ) ) )
}
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( common_ty . unwrap ( ) , ctx . primitives . float ) ) ;
call_min ( ctx , ( x1_ty , x1 . into ( ) ) , ( x2_ty , x2 . into ( ) ) )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-05-08 18:29:11 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-05-08 18:29:11 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-05-08 18:29:11 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_minimum ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-05-08 18:29:11 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-05-08 18:29:11 +08:00
} )
}
2024-04-24 17:40:25 +08:00
/// Invokes the `max` builtin function.
pub fn call_max < ' ctx > (
ctx : & mut CodeGenContext < ' ctx , '_ > ,
m : ( Type , BasicValueEnum < ' ctx > ) ,
n : ( Type , BasicValueEnum < ' ctx > ) ,
) -> BasicValueEnum < ' ctx > {
const FN_NAME : & str = " max " ;
let ( m_ty , m ) = m ;
let ( n_ty , n ) = n ;
2024-04-25 15:47:16 +08:00
let common_ty = if ctx . unifier . unioned ( m_ty , n_ty ) {
m_ty
} else {
2024-04-24 17:40:25 +08:00
unsupported_type ( ctx , FN_NAME , & [ m_ty , n_ty ] )
2024-04-25 15:47:16 +08:00
} ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
match ( m , n ) {
( BasicValueEnum ::IntValue ( m ) , BasicValueEnum ::IntValue ( n ) ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty , * ty ) ) ) ;
if [ ctx . primitives . int32 , ctx . primitives . int64 ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty , * ty ) )
{
2024-04-24 17:40:25 +08:00
llvm_intrinsics ::call_int_smax ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
} else {
llvm_intrinsics ::call_int_umax ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
}
}
2024-04-25 15:47:16 +08:00
( BasicValueEnum ::FloatValue ( m ) , BasicValueEnum ::FloatValue ( n ) ) = > {
2024-04-24 17:40:25 +08:00
debug_assert! ( ctx . unifier . unioned ( common_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_maxnum ( ctx , m , n , Some ( FN_NAME ) ) . into ( )
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ m_ty , n_ty ] ) ,
2024-04-24 17:40:25 +08:00
}
}
2024-07-12 18:18:54 +08:00
/// Invokes the np_max, np_min, np_argmax, np_argmin functions
/// * `fn_name`: Can be one of "np_argmin", "np_argmax", "np_max", "np_min"
pub fn call_numpy_max_min < ' ctx , G : CodeGenerator + ? Sized > (
2024-05-08 17:42:19 +08:00
generator : & mut G ,
ctx : & mut CodeGenContext < ' ctx , '_ > ,
a : ( Type , BasicValueEnum < ' ctx > ) ,
2024-07-12 18:18:54 +08:00
fn_name : & str ,
2024-05-08 17:42:19 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-07-12 18:18:54 +08:00
debug_assert! ( [ " np_argmin " , " np_argmax " , " np_max " , " np_min " ] . iter ( ) . any ( | f | * f = = fn_name ) ) ;
2024-05-08 17:42:19 +08:00
2024-07-12 18:18:54 +08:00
let llvm_int64 = ctx . ctx . i64_type ( ) ;
2024-05-08 17:42:19 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
let ( a_ty , a ) = a ;
2024-07-12 21:16:38 +08:00
Ok ( match a {
2024-05-08 17:42:19 +08:00
BasicValueEnum ::IntValue ( _ ) | BasicValueEnum ::FloatValue ( _ ) = > {
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
ctx . primitives . float ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( a_ty , * ty ) ) ) ;
2024-07-12 21:16:38 +08:00
2024-07-12 18:18:54 +08:00
match fn_name {
" np_argmin " | " np_argmax " = > llvm_int64 . const_zero ( ) . into ( ) ,
" np_max " | " np_min " = > a ,
2024-07-12 21:16:38 +08:00
_ = > unreachable! ( ) ,
2024-07-12 18:18:54 +08:00
}
2024-05-08 17:42:19 +08:00
}
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( n )
2024-06-12 15:01:01 +08:00
if a_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-07-12 21:16:38 +08:00
let ( elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , a_ty ) ;
2024-05-08 17:42:19 +08:00
let llvm_ndarray_ty = ctx . get_llvm_type ( generator , elem_ty ) ;
let n = NDArrayValue ::from_ptr_val ( n , llvm_usize , None ) ;
2024-05-27 15:58:06 +08:00
let n_sz = irrt ::call_ndarray_calc_size ( generator , ctx , & n . dim_sizes ( ) , ( None , None ) ) ;
2024-05-08 17:42:19 +08:00
if ctx . registry . llvm_options . opt_level = = OptimizationLevel ::None {
2024-06-12 14:45:03 +08:00
let n_sz_eqz = ctx
. builder
. build_int_compare ( IntPredicate ::NE , n_sz , n_sz . get_type ( ) . const_zero ( ) , " " )
2024-05-08 17:42:19 +08:00
. unwrap ( ) ;
ctx . make_assert (
generator ,
n_sz_eqz ,
" 0:ValueError " ,
2024-07-12 18:18:54 +08:00
format! ( " zero-size array to reduction operation {} " , fn_name ) . as_str ( ) ,
2024-05-08 17:42:19 +08:00
[ None , None , None ] ,
ctx . current_loc ,
) ;
}
let accumulator_addr = generator . gen_var_alloc ( ctx , llvm_ndarray_ty , None ) ? ;
2024-07-12 18:18:54 +08:00
let res_idx = generator . gen_var_alloc ( ctx , llvm_int64 . into ( ) , None ) ? ;
2024-05-08 17:42:19 +08:00
unsafe {
2024-06-12 14:45:03 +08:00
let identity =
n . data ( ) . get_unchecked ( ctx , generator , & llvm_usize . const_zero ( ) , None ) ;
2024-05-08 17:42:19 +08:00
ctx . builder . build_store ( accumulator_addr , identity ) . unwrap ( ) ;
2024-07-12 18:18:54 +08:00
ctx . builder . build_store ( res_idx , llvm_int64 . const_zero ( ) ) . unwrap ( ) ;
2024-05-08 17:42:19 +08:00
}
gen_for_callback_incrementing (
generator ,
ctx ,
2024-07-12 18:18:54 +08:00
llvm_int64 . const_int ( 1 , false ) ,
2024-05-08 17:42:19 +08:00
( n_sz , false ) ,
2024-07-12 21:16:38 +08:00
| generator , ctx , _ , idx | {
2024-06-12 14:45:03 +08:00
let elem = unsafe { n . data ( ) . get_unchecked ( ctx , generator , & idx , None ) } ;
2024-05-08 17:42:19 +08:00
let accumulator = ctx . builder . build_load ( accumulator_addr , " " ) . unwrap ( ) ;
2024-07-12 18:18:54 +08:00
let cur_idx = ctx . builder . build_load ( res_idx , " " ) . unwrap ( ) ;
let result = match fn_name {
2024-07-12 21:16:38 +08:00
" np_argmin " | " np_min " = > {
call_min ( ctx , ( elem_ty , accumulator ) , ( elem_ty , elem ) )
}
" np_argmax " | " np_max " = > {
call_max ( ctx , ( elem_ty , accumulator ) , ( elem_ty , elem ) )
}
_ = > unreachable! ( ) ,
2024-07-12 18:18:54 +08:00
} ;
2024-07-12 21:16:38 +08:00
let updated_idx = match ( accumulator , result ) {
( BasicValueEnum ::IntValue ( m ) , BasicValueEnum ::IntValue ( n ) ) = > ctx
. builder
. build_select (
ctx . builder . build_int_compare ( IntPredicate ::NE , m , n , " " ) . unwrap ( ) ,
idx . into ( ) ,
2024-07-12 18:18:54 +08:00
cur_idx ,
2024-07-12 21:16:38 +08:00
" " ,
)
. unwrap ( ) ,
( BasicValueEnum ::FloatValue ( m ) , BasicValueEnum ::FloatValue ( n ) ) = > ctx
. builder
. build_select (
ctx . builder
. build_float_compare ( FloatPredicate ::ONE , m , n , " " )
. unwrap ( ) ,
idx . into ( ) ,
2024-07-12 18:18:54 +08:00
cur_idx ,
2024-07-12 21:16:38 +08:00
" " ,
)
. unwrap ( ) ,
2024-07-12 18:18:54 +08:00
_ = > unsupported_type ( ctx , fn_name , & [ elem_ty , elem_ty ] ) ,
} ;
ctx . builder . build_store ( res_idx , updated_idx ) . unwrap ( ) ;
2024-05-08 17:42:19 +08:00
ctx . builder . build_store ( accumulator_addr , result ) . unwrap ( ) ;
Ok ( ( ) )
} ,
2024-07-12 18:18:54 +08:00
llvm_int64 . const_int ( 1 , false ) ,
2024-05-08 17:42:19 +08:00
) ? ;
2024-07-12 18:18:54 +08:00
match fn_name {
" np_argmin " | " np_argmax " = > ctx . builder . build_load ( res_idx , " " ) . unwrap ( ) ,
" np_max " | " np_min " = > ctx . builder . build_load ( accumulator_addr , " " ) . unwrap ( ) ,
2024-07-12 21:16:38 +08:00
_ = > unreachable! ( ) ,
2024-07-12 18:18:54 +08:00
}
2024-05-08 17:42:19 +08:00
}
2024-07-12 21:16:38 +08:00
_ = > unsupported_type ( ctx , fn_name , & [ a_ty ] ) ,
2024-05-08 17:42:19 +08:00
} )
}
2024-05-08 18:29:11 +08:00
/// Invokes the `np_maximum` builtin function.
pub fn call_numpy_maximum < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
ctx : & mut CodeGenContext < ' ctx , '_ > ,
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_maximum " ;
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-06-12 14:45:03 +08:00
let common_ty = if ctx . unifier . unioned ( x1_ty , x2_ty ) { Some ( x1_ty ) } else { None } ;
2024-05-08 18:29:11 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::IntValue ( x1 ) , BasicValueEnum ::IntValue ( x2 ) ) = > {
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
ctx . primitives . float ,
2024-06-12 14:45:03 +08:00
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( common_ty . unwrap ( ) , * ty ) ) ) ;
2024-05-08 18:29:11 +08:00
call_max ( ctx , ( x1_ty , x1 . into ( ) ) , ( x2_ty , x2 . into ( ) ) )
}
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( common_ty . unwrap ( ) , ctx . primitives . float ) ) ;
call_max ( ctx , ( x1_ty , x1 . into ( ) ) , ( x2_ty , x2 . into ( ) ) )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-05-08 18:29:11 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-05-08 18:29:11 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-05-08 18:29:11 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_maximum ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-05-08 18:29:11 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-05-08 18:29:11 +08:00
} )
}
2024-06-20 12:48:44 +08:00
/// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar.
///
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
/// Return a constant [`Type`] here if the return type does not depend on the input type.
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
fn helper_call_numpy_unary_elementwise < ' ctx , OnScalarFn , RetElemFn , G > (
2024-04-24 17:40:25 +08:00
generator : & mut G ,
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-06-20 12:48:44 +08:00
( arg_ty , arg_val ) : ( Type , BasicValueEnum < ' ctx > ) ,
fn_name : & str ,
get_ret_elem_type : & RetElemFn ,
on_scalar : & OnScalarFn ,
) -> Result < BasicValueEnum < ' ctx > , String >
where
G : CodeGenerator + ? Sized ,
OnScalarFn : Fn (
& mut G ,
& mut CodeGenContext < ' ctx , '_ > ,
Type ,
BasicValueEnum < ' ctx > ,
) -> Option < BasicValueEnum < ' ctx > > ,
RetElemFn : Fn ( & mut CodeGenContext < ' ctx , '_ > , Type ) -> Type ,
{
let result = match arg_val {
2024-06-12 14:45:03 +08:00
BasicValueEnum ::PointerValue ( x )
2024-06-20 12:48:44 +08:00
if arg_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) = >
2024-06-12 14:45:03 +08:00
{
2024-06-20 12:48:44 +08:00
let llvm_usize = generator . get_size_type ( ctx . ctx ) ;
let ( arg_elem_ty , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , arg_ty ) ;
let ret_elem_ty = get_ret_elem_type ( ctx , arg_elem_ty ) ;
2024-04-25 15:47:16 +08:00
let ndarray = ndarray_elementwise_unaryop_impl (
generator ,
ctx ,
2024-06-20 12:48:44 +08:00
ret_elem_ty ,
2024-04-25 15:47:16 +08:00
None ,
NDArrayValue ::from_ptr_val ( x , llvm_usize , None ) ,
2024-06-20 12:48:44 +08:00
| generator , ctx , elem_val | {
helper_call_numpy_unary_elementwise (
generator ,
ctx ,
( arg_elem_ty , elem_val ) ,
fn_name ,
get_ret_elem_type ,
on_scalar ,
)
2024-04-25 15:47:16 +08:00
} ,
) ? ;
2024-06-06 12:16:09 +08:00
ndarray . as_base_value ( ) . into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-20 12:48:44 +08:00
_ = > on_scalar ( generator , ctx , arg_ty , arg_val )
. unwrap_or_else ( | | unsupported_type ( ctx , fn_name , & [ arg_ty ] ) ) ,
} ;
2024-04-25 15:47:16 +08:00
2024-06-20 12:48:44 +08:00
Ok ( result )
2024-04-24 17:40:25 +08:00
}
2024-06-20 12:48:44 +08:00
pub fn call_abs < ' ctx , G : CodeGenerator + ? Sized > (
2024-04-25 15:47:16 +08:00
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-06-20 12:48:44 +08:00
n : ( Type , BasicValueEnum < ' ctx > ) ,
2024-04-25 15:47:16 +08:00
) -> Result < BasicValueEnum < ' ctx > , String > {
2024-06-20 12:48:44 +08:00
const FN_NAME : & str = " abs " ;
helper_call_numpy_unary_elementwise (
generator ,
ctx ,
n ,
FN_NAME ,
& | _ctx , elem_ty | elem_ty ,
& | _generator , ctx , val_ty , val | match val {
BasicValueEnum ::IntValue ( n ) = > Some ( {
debug_assert! ( [
ctx . primitives . bool ,
ctx . primitives . int32 ,
ctx . primitives . uint32 ,
ctx . primitives . int64 ,
ctx . primitives . uint64 ,
]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( val_ty , * ty ) ) ) ;
if [ ctx . primitives . int32 , ctx . primitives . int64 ]
. iter ( )
. any ( | ty | ctx . unifier . unioned ( val_ty , * ty ) )
{
llvm_intrinsics ::call_int_abs (
ctx ,
n ,
ctx . ctx . bool_type ( ) . const_zero ( ) ,
Some ( FN_NAME ) ,
)
. into ( )
} else {
n . into ( )
}
} ) ,
BasicValueEnum ::FloatValue ( n ) = > Some ( {
debug_assert! ( ctx . unifier . unioned ( val_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_fabs ( ctx , n , Some ( FN_NAME ) ) . into ( )
} ) ,
_ = > None ,
} ,
)
2024-04-24 17:40:25 +08:00
}
2024-06-20 12:48:44 +08:00
/// Generates a public numpy-style unary function backed by
/// [`helper_call_numpy_unary_elementwise`].
///
/// Arguments:
/// * `$name:ident`: The identifier of the Rust function to be generated.
/// * `$fn_name:literal`: Forwarded as the `fn_name` parameter of
///   [`helper_call_numpy_unary_elementwise`].
/// * `$get_ret_elem_type:expr`: Forwarded by reference as the `get_ret_elem_type` parameter of
///   [`helper_call_numpy_unary_elementwise`]; write it without a leading `&`.
/// * `$on_scalar:expr`: Forwarded by reference as the `on_scalar` parameter of
///   [`helper_call_numpy_unary_elementwise`]; write it without a leading `&`.
macro_rules! create_helper_call_numpy_unary_elementwise {
    ($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
        #[doc = concat!(
            "Invokes the `",
            $fn_name,
            "` builtin function on a scalar, or elementwise on an ndarray."
        )]
        #[allow(clippy::redundant_closure_call)]
        pub fn $name<'ctx, G: CodeGenerator + ?Sized>(
            generator: &mut G,
            ctx: &mut CodeGenContext<'ctx, '_>,
            arg: (Type, BasicValueEnum<'ctx>),
        ) -> Result<BasicValueEnum<'ctx>, String> {
            let get_ret_elem_type = &$get_ret_elem_type;
            let on_scalar = &$on_scalar;
            helper_call_numpy_unary_elementwise(
                generator,
                ctx,
                arg,
                $fn_name,
                get_ret_elem_type,
                on_scalar,
            )
        }
    };
}
2024-06-20 12:48:44 +08:00
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] that generates
/// functions mapping floats to booleans (stored as `i8`) elementwise.
///
/// Arguments:
/// * `$name:ident`: The identifier of the Rust function to be generated.
/// * `$fn_name:literal`: Forwarded as the `fn_name` parameter of
///   [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The callable (see below for its type) applied to each float scalar. It
///   returns the boolean result as an LLVM `i1`, which is then widened to an `i8` with
///   `bool_to_i8`.
///
/// ```ignore
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
///     generator: &mut G,
///     ctx: &mut CodeGenContext<'ctx, '_>,
///     arg: FloatValue<'ctx>
/// ) -> IntValue<'ctx> // of LLVM type `i1`
/// ```
///
/// Non-float scalars are rejected, so the generated function panics via [`unsupported_type`].
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool {
    ($name:ident, $fn_name:literal, $on_scalar:expr) => {
        create_helper_call_numpy_unary_elementwise!(
            $name,
            $fn_name,
            |ctx, _| ctx.primitives.bool,
            |generator, ctx, val_ty, val| {
                let BasicValueEnum::FloatValue(float_val) = val else {
                    return None;
                };
                debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));

                let predicate = $on_scalar(generator, ctx, float_val);
                Some(generator.bool_to_i8(ctx, predicate).into())
            }
        );
    };
}
2024-06-20 12:48:44 +08:00
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] that generates
/// functions mapping floats to floats elementwise.
///
/// Arguments:
/// * `$name:ident`: The identifier of the Rust function to be generated.
/// * `$fn_name:literal`: Forwarded as the `fn_name` parameter of
///   [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The callable (see below for its type) applied to each float scalar,
///   returning the float result. Unlike the bool-returning variant, it does NOT receive the code
///   generator. Its third argument is an optional value name, which is always passed as `None`.
///   Both a function path (e.g. `llvm_intrinsics::call_float_sin`) and a closure
///   (e.g. `|ctx, val, _| irrt::call_gamma(ctx, val)`) are accepted.
///
/// ```ignore
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx>(
///     ctx: &mut CodeGenContext<'ctx, '_>,
///     arg: FloatValue<'ctx>,
///     name: Option<&str>,
/// ) -> FloatValue<'ctx>
/// ```
///
/// Non-float scalars are rejected, so the generated function panics via [`unsupported_type`].
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float {
    ($name:ident, $fn_name:literal, $on_scalar:expr) => {
        create_helper_call_numpy_unary_elementwise!(
            $name,
            $fn_name,
            |ctx, _| ctx.primitives.float,
            |_generator, ctx, val_ty, val| {
                let BasicValueEnum::FloatValue(float_val) = val else {
                    return None;
                };
                debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));

                Some($on_scalar(ctx, float_val, Option::<&str>::None).into())
            }
        );
    };
}
2024-06-20 12:48:44 +08:00
// Float classification predicates (results stored as `i8` booleans).
create_helper_call_numpy_unary_elementwise_float_to_bool!(
    call_numpy_isnan,
    "np_isnan",
    irrt::call_isnan
);
create_helper_call_numpy_unary_elementwise_float_to_bool!(
    call_numpy_isinf,
    "np_isinf",
    irrt::call_isinf
);

// Trigonometric functions and their inverses.
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_sin,
    "np_sin",
    llvm_intrinsics::call_float_sin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_cos,
    "np_cos",
    llvm_intrinsics::call_float_cos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_tan,
    "np_tan",
    extern_fns::call_tan
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arcsin,
    "np_arcsin",
    extern_fns::call_asin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arccos,
    "np_arccos",
    extern_fns::call_acos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arctan,
    "np_arctan",
    extern_fns::call_atan
);

// Hyperbolic functions and their inverses.
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_sinh,
    "np_sinh",
    extern_fns::call_sinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_cosh,
    "np_cosh",
    extern_fns::call_cosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_tanh,
    "np_tanh",
    extern_fns::call_tanh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arcsinh,
    "np_arcsinh",
    extern_fns::call_asinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arccosh,
    "np_arccosh",
    extern_fns::call_acosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_arctanh,
    "np_arctanh",
    extern_fns::call_atanh
);

// Exponentials and logarithms.
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_exp,
    "np_exp",
    llvm_intrinsics::call_float_exp
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_exp2,
    "np_exp2",
    llvm_intrinsics::call_float_exp2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_expm1,
    "np_expm1",
    extern_fns::call_expm1
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_log,
    "np_log",
    llvm_intrinsics::call_float_log
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_log2,
    "np_log2",
    llvm_intrinsics::call_float_log2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_log10,
    "np_log10",
    llvm_intrinsics::call_float_log10
);

// Roots, magnitude and rounding.
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_sqrt,
    "np_sqrt",
    llvm_intrinsics::call_float_sqrt
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_cbrt,
    "np_cbrt",
    extern_fns::call_cbrt
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_fabs,
    "np_fabs",
    llvm_intrinsics::call_float_fabs
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_numpy_rint,
    "np_rint",
    llvm_intrinsics::call_float_rint
);

// `scipy.special` functions. The IRRT helpers take no value-name argument, so the closures
// discard it.
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_erf,
    "sp_spec_erf",
    extern_fns::call_erf
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_erfc,
    "sp_spec_erfc",
    extern_fns::call_erfc
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_gamma,
    "sp_spec_gamma",
    |ctx, x, _name| irrt::call_gamma(ctx, x)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_gammaln,
    "sp_spec_gammaln",
    |ctx, x, _name| irrt::call_gammaln(ctx, x)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_j0,
    "sp_spec_j0",
    |ctx, x, _name| irrt::call_j0(ctx, x)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
    call_scipy_special_j1,
    "sp_spec_j1",
    extern_fns::call_j1
);
2024-04-24 17:40:25 +08:00
/// Invokes the `np_arctan2` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_arctan2 < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_arctan2 " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
extern_fns ::call_atan2 ( ctx , x1 , x2 , None ) . into ( )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
2024-06-06 12:16:09 +08:00
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
2024-04-25 15:47:16 +08:00
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
2024-06-06 12:16:09 +08:00
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
2024-04-25 15:47:16 +08:00
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_arctan2 ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_copysign` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_copysign < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_copysign " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_copysign ( ctx , x1 , x2 , None ) . into ( )
}
2024-04-24 17:40:25 +08:00
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_copysign ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_fmax` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_fmax < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_fmax " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_maxnum ( ctx , x1 , x2 , None ) . into ( )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_fmax ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_fmin` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_fmin < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_fmin " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
llvm_intrinsics ::call_float_minnum ( ctx , x1 , x2 , None ) . into ( )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_fmin ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_ldexp` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_ldexp < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_ldexp " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::IntValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . int32 ) ) ;
extern_fns ::call_ldexp ( ctx , x1 , x2 , None ) . into ( )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
2024-06-12 14:45:03 +08:00
let dtype =
if is_ndarray1 { unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0 } else { x1_ty } ;
2024-04-24 17:40:25 +08:00
2024-04-25 15:47:16 +08:00
let x1_scalar_ty = dtype ;
2024-06-12 14:45:03 +08:00
let x2_scalar_ty =
if is_ndarray2 { unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0 } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_ldexp ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_hypot` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_hypot < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_hypot " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
extern_fns ::call_hypot ( ctx , x1 , x2 , None ) . into ( )
}
2024-04-24 17:40:25 +08:00
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_hypot ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-04-24 17:40:25 +08:00
}
/// Invokes the `np_nextafter` builtin function.
2024-04-25 15:47:16 +08:00
pub fn call_numpy_nextafter < ' ctx , G : CodeGenerator + ? Sized > (
generator : & mut G ,
2024-04-24 17:40:25 +08:00
ctx : & mut CodeGenContext < ' ctx , '_ > ,
2024-04-25 15:47:16 +08:00
x1 : ( Type , BasicValueEnum < ' ctx > ) ,
x2 : ( Type , BasicValueEnum < ' ctx > ) ,
) -> Result < BasicValueEnum < ' ctx > , String > {
const FN_NAME : & str = " np_nextafter " ;
2024-04-24 17:40:25 +08:00
let ( x1_ty , x1 ) = x1 ;
let ( x2_ty , x2 ) = x2 ;
2024-04-25 15:47:16 +08:00
Ok ( match ( x1 , x2 ) {
( BasicValueEnum ::FloatValue ( x1 ) , BasicValueEnum ::FloatValue ( x2 ) ) = > {
debug_assert! ( ctx . unifier . unioned ( x1_ty , ctx . primitives . float ) ) ;
debug_assert! ( ctx . unifier . unioned ( x2_ty , ctx . primitives . float ) ) ;
extern_fns ::call_nextafter ( ctx , x1 , x2 , None ) . into ( )
}
2024-06-12 14:45:03 +08:00
( x1 , x2 )
if [ & x1_ty , & x2_ty ] . into_iter ( ) . any ( | ty | {
2024-06-12 15:01:01 +08:00
ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) )
2024-06-12 14:45:03 +08:00
} ) = >
{
let is_ndarray1 =
2024-06-12 15:01:01 +08:00
x1_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-06-12 14:45:03 +08:00
let is_ndarray2 =
2024-06-12 15:01:01 +08:00
x2_ty . obj_id ( & ctx . unifier ) . is_some_and ( | id | id = = PrimDef ::NDArray . id ( ) ) ;
2024-04-25 15:47:16 +08:00
let dtype = if is_ndarray1 & & is_ndarray2 {
let ( ndarray_dtype1 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) ;
let ( ndarray_dtype2 , _ ) = unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) ;
debug_assert! ( ctx . unifier . unioned ( ndarray_dtype1 , ndarray_dtype2 ) ) ;
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x1_ty ) . 0
} else if is_ndarray2 {
unpack_ndarray_var_tys ( & mut ctx . unifier , x2_ty ) . 0
} else {
2024-06-12 14:45:03 +08:00
unreachable! ( )
2024-04-25 15:47:16 +08:00
} ;
2024-06-12 14:45:03 +08:00
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty } ;
let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty } ;
2024-04-25 15:47:16 +08:00
numpy ::ndarray_elementwise_binop_impl (
generator ,
ctx ,
dtype ,
None ,
( x1 , ! is_ndarray1 ) ,
( x2 , ! is_ndarray2 ) ,
| generator , ctx , ( lhs , rhs ) | {
call_numpy_nextafter ( generator , ctx , ( x1_scalar_ty , lhs ) , ( x2_scalar_ty , rhs ) )
} ,
2024-06-12 14:45:03 +08:00
) ?
. as_base_value ( )
. into ( )
2024-04-25 15:47:16 +08:00
}
2024-04-24 17:40:25 +08:00
2024-06-12 14:45:03 +08:00
_ = > unsupported_type ( ctx , FN_NAME , & [ x1_ty , x2_ty ] ) ,
2024-04-25 15:47:16 +08:00
} )
2024-06-12 14:45:03 +08:00
}