From b9e00eb8a5ac2b9a654f313a8e227b103ecd44aa Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 22:16:51 +0800 Subject: [PATCH] core/ndstrides: partially update builtin_fns to use ndarray with strides nalgebra functions are not updated --- nac3core/src/codegen/builtin_fns.rs | 1078 ++++++++++----------------- 1 file changed, 405 insertions(+), 673 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4cfc2484..ebede6ed 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -8,8 +8,7 @@ use crate::codegen::classes::{ }; use crate::codegen::expr::destructure_range; use crate::codegen::irrt::calculate_len_for_slice_range; -use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; -use crate::codegen::stmt::gen_for_callback_incrementing; +use crate::codegen::object::ndarray::{NDArrayOut, ScalarOrNDArray}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; @@ -67,7 +66,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; Ok(match n { @@ -101,21 +99,20 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int32 }, + |generator, ctx, scalar| call_int32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -129,7 +126,6 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -163,21 +159,20 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int64 }, + |generator, ctx, scalar| call_int64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -191,7 +186,6 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -241,21 +235,20 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint32 }, + |generator, ctx, scalar| call_uint32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -269,7 +262,6 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -308,21 +300,20 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint64 }, + |generator, ctx, scalar| call_uint64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -336,7 +327,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -374,21 +364,20 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n.into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| call_float(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -404,8 +393,6 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); @@ -420,21 +407,22 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_round(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -449,8 +437,6 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -460,21 +446,22 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_rint(ctx, n, None).into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| { + call_numpy_round(generator, ctx, (ndarray.dtype, scalar)) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -489,8 +476,6 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "bool"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -525,25 +510,22 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; - - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (ndarray.dtype, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -559,8 +541,6 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "floor"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -579,21 +559,21 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -609,8 +589,6 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ceil"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -629,21 +607,21 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -735,47 +713,32 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -842,7 +805,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (a_ty, a) = a; Ok(match a { @@ -864,23 +826,21 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => unreachable!(), } } - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + _ if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: a_ty, value: a }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let dtype_llvm = ctx.get_llvm_type(generator, ndarray.dtype); + + let zero = Int(SizeT).const_0(generator, ctx.ctx); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); + let size_isnt_zero = + ndarray.size(generator, ctx).compare(ctx, IntPredicate::NE, zero); ctx.make_assert( generator, - n_sz_eqz, + size_isnt_zero.value, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -888,44 +848,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, dtype_llvm, None)?; + let extremum_idx = Int(SizeT).var_alloca(generator, ctx, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = ndarray.get_nth_scalar(generator, ctx, zero).value; + ctx.builder.build_store(extremum, first_value).unwrap(); + extremum_idx.store(ctx, zero); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = extremum_idx.load(generator, ctx); - let result = match fn_name { - "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } - "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } + let curr_value = nditer.get_scalar(generator, ctx).value; + let curr_idx = nditer.get_index(generator, ctx); + + let new_extremum = match fn_name { + "np_argmin" | "np_min" => call_min( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), + "np_argmax" | "np_max" => call_max( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), _ => unreachable!(), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), @@ -935,24 +894,31 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), - _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), + _ => unsupported_type(ctx, fn_name, &[ndarray.dtype, ndarray.dtype]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + + let new_extremum_idx = + Int(SizeT).believe_value(new_extremum_idx.into_int_value()); + extremum_idx.store(ctx, new_extremum_idx); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => extremum_idx + .load(generator, ctx) + .s_extend(generator, ctx, Int64) + .value + .as_basic_value_enum(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => unreachable!(), } } @@ -997,47 +963,32 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1070,39 +1021,19 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let arg = ScalarOrNDArray::split_object(generator, ctx, arg); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arg.get_dtype(); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; - - Ok(result) + let ret_ty = get_ret_elem_type(ctx, dtype); + let result = arg.map(generator, ctx, ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1421,391 +1352,220 @@ create_helper_call_numpy_unary_elementwise_float_to_float!( pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert!(ctx.unifier.unioned(x1.get_dtype(), ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2.get_dtype(), ctx.primitives.int32)); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1820,59 +1580,31 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it