From 7f3c45302df54255bb17ffeff6b72a37ac1f4bb9 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 22:16:51 +0800 Subject: [PATCH] core/ndstrides: update builtin_fns to use ndarray with strides --- nac3core/src/codegen/builtin_fns.rs | 1702 ++++++++++----------------- 1 file changed, 635 insertions(+), 1067 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index d2e02bc..8989c48 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,28 +1,28 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; use super::{ - classes::{ - NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, - }, + classes::RangeValue, expr::destructure_range, extern_fns, irrt, irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, model::*, - numpy, - numpy::ndarray_elementwise_unaryop_impl, - object::{any::AnyObject, list::ListObject, ndarray::NDArrayObject, tuple::TupleObject}, - stmt::gen_for_callback_incrementing, + object::{ + any::AnyObject, + list::ListObject, + ndarray::{NDArrayObject, NDArrayOut, ScalarOrNDArray}, + tuple::TupleObject, + }, CodeGenContext, CodeGenerator, }; use crate::{ - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys}, + toplevel::helper::PrimDef, typecheck::typedef::{Type, TypeEnum}, }; @@ -80,7 +80,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; Ok(match n { @@ -114,21 +113,20 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int32 }, + |generator, ctx, scalar| call_int32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -142,7 +140,6 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -176,21 +173,20 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int64 }, + |generator, ctx, scalar| call_int64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -204,7 +200,6 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -254,21 +249,20 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint32 }, + |generator, ctx, scalar| call_uint32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -282,7 +276,6 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -321,21 +314,20 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint64 }, + |generator, ctx, scalar| call_uint64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -349,7 +341,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -387,21 +378,20 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n.into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| call_float(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -417,8 +407,6 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); @@ -433,21 +421,22 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_round(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -462,8 +451,6 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -473,21 +460,22 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_rint(ctx, n, None).into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| { + call_numpy_round(generator, ctx, (ndarray.dtype, scalar)) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -502,8 +490,6 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "bool"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -538,25 +524,22 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; - - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (ndarray.dtype, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -572,8 +555,6 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "floor"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -592,21 +573,21 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -622,8 +603,6 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ceil"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -642,21 +621,21 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -748,47 +727,32 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -855,7 +819,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (a_ty, a) = a; Ok(match a { @@ -877,23 +840,21 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => codegen_unreachable!(ctx), } } - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + _ if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: a_ty, value: a }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let dtype_llvm = ctx.get_llvm_type(generator, ndarray.dtype); + + let zero = Int(SizeT).const_0(generator, ctx.ctx); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); + let size_isnt_zero = + ndarray.size(generator, ctx).compare(ctx, IntPredicate::NE, zero); ctx.make_assert( generator, - n_sz_eqz, + size_isnt_zero.value, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -901,44 +862,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, dtype_llvm, None)?; + let extremum_idx = Int(SizeT).var_alloca(generator, ctx, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = ndarray.get_nth_scalar(generator, ctx, zero).value; + ctx.builder.build_store(extremum, first_value).unwrap(); + extremum_idx.store(ctx, zero); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = extremum_idx.load(generator, ctx); - let result = match fn_name { - "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } - "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } + let curr_value = nditer.get_scalar(generator, ctx).value; + let curr_idx = nditer.get_index(generator, ctx); + + let new_extremum = match fn_name { + "np_argmin" | "np_min" => call_min( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), + "np_argmax" | "np_max" => call_max( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), _ => codegen_unreachable!(ctx), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), @@ -948,24 +908,31 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), - _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), + _ => unsupported_type(ctx, fn_name, &[ndarray.dtype, ndarray.dtype]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + + let new_extremum_idx = + unsafe { Int(SizeT).believe_value(new_extremum_idx.into_int_value()) }; + extremum_idx.store(ctx, new_extremum_idx); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => extremum_idx + .load(generator, ctx) + .s_extend_or_bit_cast(generator, ctx, Int64) + .value + .as_basic_value_enum(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => codegen_unreachable!(ctx), } } @@ -1010,47 +977,32 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1083,39 +1035,19 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let arg = ScalarOrNDArray::split_object(generator, ctx, arg); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arg.get_dtype(); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; - - Ok(result) + let ret_ty = get_ret_elem_type(ctx, dtype); + let result = arg.map(generator, ctx, ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1434,391 +1366,220 @@ create_helper_call_numpy_unary_elementwise_float_to_float!( pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert!(ctx.unifier.unioned(x1.get_dtype(), ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2.get_dtype(), ctx.primitives.int32)); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1833,555 +1594,362 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - codegen_unreachable!(ctx) - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) -} - -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: Vec>, -) -> PointerValue<'ctx> { - let field_ty = - out_matrices.iter().map(BasicValueEnum::get_type).collect::>(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.into_iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, v).unwrap(); - } - } - out_ptr + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_cholesky"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_cholesky( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_qr` linalg function pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_qr"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = unsafe { + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)) + }; - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unimplemented!("{FN_NAME} operates on float type NdArrays only"); - }; + let q = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]); + q.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let r = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]); + r.create_data(generator, ctx); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let r_c = r.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_qr( + ctx, + x1_c.value.as_basic_value_enum(), + q_c.value.as_basic_value_enum(), + r_c.value.as_basic_value_enum(), + None, + ); - extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); - - let out_ptr = build_output_struct(ctx, vec![out_q, out_r]); - - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let q = q.to_any(ctx); + let r = r.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [q, r]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_svd` linalg function pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_svd"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = unsafe { + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)) + }; - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, d0]); + u.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let s = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk]); + s.create_data(generator, ctx); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let vh = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d1, d1]); + vh.create_data(generator, ctx); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let s_c = s.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let vh_c = vh.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_svd( + ctx, + x1_c.value.as_basic_value_enum(), + u_c.value.as_basic_value_enum(), + s_c.value.as_basic_value_enum(), + vh_c.value.as_basic_value_enum(), + None, + ); - extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); - - let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let u = u.to_any(ctx); + let s = s.to_any(ctx); + let vh = vh.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [u, s, vh]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_inv` linalg function pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_inv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, x1.dtype, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_inv( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_inv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_pinv` linalg function pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_pinv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let out = NDArrayObject::alloca_dynamic_shape(generator, ctx, x1.dtype, &[d1, d0]); + out.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_pinv( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_pinv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `sp_linalg_lu` linalg function pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_lu"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = unsafe { + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)) + }; - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let l = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]); + l.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]); + u.create_data(generator, ctx); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let l_c = l.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_lu( + ctx, + x1_c.value.as_basic_value_enum(), + l_c.value.as_basic_value_enum(), + u_c.value.as_basic_value_enum(), + None, + ); - let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); - - let out_ptr = build_output_struct(ctx, vec![out_l, out_u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let l = l.to_any(ctx); + let u = u.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [l, u]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_matrix_power` linalg function pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + + // x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call. let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); + let x2 = AnyObject { ty: ctx.primitives.float, value: x2 }; + let x2 = NDArrayObject::make_unsized(generator, ctx, x2); // x2.shape == [] + let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let llvm_usize = generator.get_size_type(ctx.ctx); - if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let x2_c = x2.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_matrix_power( + ctx, + x1_c.value.as_basic_value_enum(), + x2_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - // Changing second parameter to a `NDArray` for uniformity in function call - let n2_array = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - unsafe { - n2_array.data().set_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - n2.as_basic_value_enum(), - ); - }; - let n2_array = n2_array.as_base_value().as_basic_value_enum(); - - let outdim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let outdim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_det` linalg function pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(_) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. + let det = NDArrayObject::alloca_constant_shape(generator, ctx, ctx.primitives.float, &[1]); + det.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = det.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_det( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - // Changing second parameter to a `NDArray` for uniformity in function call - let out = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); - let res = - unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - Ok(res) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + // Get the determinant out of `out` + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let det = det.get_nth_scalar(generator, ctx, zero); + Ok(det.value) } /// Invokes the `sp_linalg_schur` linalg function pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_schur"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + assert_eq!(x1.ndims, 2); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let t = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + t.copy_shape_from_ndarray(generator, ctx, x1); + t.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let z = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + z.copy_shape_from_ndarray(generator, ctx, x1); + z.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let t_c = t.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let z_c = z.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_schur( + ctx, + x1_c.value.as_basic_value_enum(), + t_c.value.as_basic_value_enum(), + z_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); - - let out_ptr = build_output_struct(ctx, vec![out_t, out_z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let t = t.to_any(ctx); + let z = z.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [t, z]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `sp_linalg_hessenberg` linalg function pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_hessenberg"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + assert_eq!(x1.ndims, 2); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let h = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + h.copy_shape_from_ndarray(generator, ctx, x1); + h.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let q = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + q.copy_shape_from_ndarray(generator, ctx, x1); + q.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let h_c = h.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_hessenberg( + ctx, + x1_c.value.as_basic_value_enum(), + h_c.value.as_basic_value_enum(), + q_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); - - let out_ptr = build_output_struct(ctx, vec![out_h, out_q]); - Ok(ctx - .builder - .build_load(out_ptr, "Hessenberg_decomposition_result") - .map(Into::into) - .unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let h = h.to_any(ctx); + let q = q.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [h, q]); + Ok(tuple.value.as_basic_value_enum()) }