From 6cbba8fdde439afad1cfab65cb3267ee483e043a Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 11:24:28 +0800 Subject: [PATCH] [core] codegen: Reimplement builtin funcs to support strided ndarrays Based on 7f3c4530: core/ndstrides: update builtin_fns to use ndarray with strides --- nac3core/src/codegen/builtin_fns.rs | 959 +++++++++++----------------- 1 file changed, 368 insertions(+), 591 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 54650ab3..7c8ad7a8 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -11,19 +11,16 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - numpy, - numpy::ndarray_elementwise_unaryop_impl, - stmt::gen_for_callback_incrementing, types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, + ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, + ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; use crate::{ toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, }, typecheck::typedef::{Type, TypeEnum}, @@ -129,18 +126,18 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_int32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -189,18 +186,18 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_int64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -265,18 +262,18 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() }, + |generator, ctx, scalar| call_uint32(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -330,18 +327,18 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() }, + |generator, ctx, scalar| call_uint64(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -355,7 +352,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { @@ -394,20 +390,19 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndims = extract_ndims(&ctx.unifier, ndims); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_float(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -440,18 +435,20 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty.into() }, + |generator, ctx, scalar| { + call_round(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -477,18 +474,18 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() }, + |generator, ctx, scalar| call_numpy_round(generator, ctx, (elem_ty, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -539,22 +536,21 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (elem_ty, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -591,18 +587,20 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -639,18 +637,20 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(n, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (elem_ty, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -741,42 +741,37 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -861,23 +856,26 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => codegen_unreachable!(ctx), } } + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty); - let n = llvm_ndarray_ty.map_value(n, None); - let n_sz = n.size(generator, ctx); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None); + let llvm_dtype = ndarray.get_type().element_type(); + + let zero = llvm_usize.const_zero(); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx + let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") .unwrap(); ctx.make_assert( generator, - n_sz_eqz, + size_nez, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -885,54 +883,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = - generator.gen_var_alloc(ctx, llvm_ndarray_ty.element_type(), None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, llvm_dtype, None)?; + let extremum_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = unsafe { ndarray.data().get_unchecked(ctx, generator, &zero, None) }; + ctx.builder.build_store(extremum, first_value).unwrap(); + ctx.builder.build_store(extremum_idx, zero).unwrap(); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { - n.data().get_unchecked( - ctx, - generator, - &ctx.builder - .build_int_truncate_or_bit_cast(idx, llvm_usize, "") - .unwrap(), - None, - ) - }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |_, ctx, _, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = ctx + .builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); - let result = match fn_name { + let curr_value = nditer.get_scalar(ctx); + let curr_idx = nditer.get_index(ctx); + + let new_extremum = match fn_name { "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_min(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) + call_max(ctx, (elem_ty, old_extremum), (elem_ty, curr_value)) } _ => codegen_unreachable!(ctx), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), @@ -942,24 +929,35 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx, + old_extremum_idx, "", ) .unwrap(), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + ctx.builder.build_store(extremum_idx, new_extremum_idx).unwrap(); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => ctx + .builder + .build_int_s_extend_or_bit_cast( + ctx.builder + .build_load(extremum_idx, "") + .map(BasicValueEnum::into_int_value) + .unwrap(), + ctx.ctx.i64_type(), + "", + ) + .unwrap() + .into(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => codegen_unreachable!(ctx), } } @@ -1006,42 +1004,37 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = + ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx); + let x2 = + ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); + let llvm_common_dtype = x1.get_type().element_type(); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[x1.get_type(), x2.get_type()], + ) + .broadcast_starmap( generator, ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + + result.as_base_value().into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1074,39 +1067,20 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = ScalarOrNDArray::from_value(generator, ctx, (arg_ty, arg_val)); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - llvm_ndarray_ty.map_value(x, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; + let ret_ty = get_ret_elem_type(ctx, dtype); + let llvm_ret_ty = ctx.get_llvm_type(generator, ret_ty); + let result = arg.map(generator, ctx, llvm_ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; - Ok(result) + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1431,59 +1405,29 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. @@ -1495,59 +1439,29 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. @@ -1559,59 +1473,29 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. @@ -1623,59 +1507,29 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. @@ -1687,48 +1541,31 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert_eq!(x1.get_dtype(), ctx.ctx.f64_type().into()); + debug_assert_eq!(x2.get_dtype(), ctx.ctx.i32_type().into()); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. @@ -1740,59 +1577,29 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1804,59 +1611,29 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_nextafter"; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)); + let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.ctx.f64_type().into(), + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1_ty, x1, !is_ndarray1), - (x2_ty, x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function