From 6cbba8fdde439afad1cfab65cb3267ee483e043a Mon Sep 17 00:00:00 2001
From: David Mak <chmakac@connect.ust.hk>
Date: Thu, 19 Dec 2024 11:24:28 +0800
Subject: [PATCH] [core] codegen: Reimplement builtin funcs to support strided
 ndarrays

Based on 7f3c4530: core/ndstrides: update builtin_fns to use ndarray
with strides
---
 nac3core/src/codegen/builtin_fns.rs | 959 +++++++++++-----------------
 1 file changed, 368 insertions(+), 591 deletions(-)

diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs
index 54650ab3..7c8ad7a8 100644
--- a/nac3core/src/codegen/builtin_fns.rs
+++ b/nac3core/src/codegen/builtin_fns.rs
@@ -11,19 +11,16 @@ use super::{
     irrt::calculate_len_for_slice_range,
     llvm_intrinsics,
     macros::codegen_unreachable,
-    numpy,
-    numpy::ndarray_elementwise_unaryop_impl,
-    stmt::gen_for_callback_incrementing,
     types::{ndarray::NDArrayType, ListType, TupleType},
     values::{
-        ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
-        UntypedArrayLikeAccessor,
+        ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray},
+        ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
     },
     CodeGenContext, CodeGenerator,
 };
 use crate::{
     toplevel::{
-        helper::{extract_ndims, PrimDef},
+        helper::{arraylike_flatten_element_type, extract_ndims, PrimDef},
         numpy::unpack_ndarray_var_tys,
     },
     typecheck::typedef::{Type, TypeEnum},
@@ -129,18 +126,18 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.int32,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() },
+                    |generator, ctx, scalar| call_int32(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, "int32", &[n_ty]),
@@ -189,18 +186,18 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.int64,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() },
+                    |generator, ctx, scalar| call_int64(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, "int64", &[n_ty]),
@@ -265,18 +262,18 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.uint32,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.i32_type().into() },
+                    |generator, ctx, scalar| call_uint32(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, "uint32", &[n_ty]),
@@ -330,18 +327,18 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.uint64,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.i64_type().into() },
+                    |generator, ctx, scalar| call_uint64(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, "uint64", &[n_ty]),
@@ -355,7 +352,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
     (n_ty, n): (Type, BasicValueEnum<'ctx>),
 ) -> Result<BasicValueEnum<'ctx>, String> {
     let llvm_f64 = ctx.ctx.f64_type();
-    let llvm_usize = generator.get_size_type(ctx.ctx);
 
     Ok(match n {
         BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
@@ -394,20 +390,19 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
         BasicValueEnum::PointerValue(n)
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
-            let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let ndims = extract_ndims(&ctx.unifier, ndims);
-            let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
+            let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.float,
-                None,
-                NDArrayValue::from_pointer_value(n, llvm_elem_ty, Some(ndims), llvm_usize, None),
-                |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() },
+                    |generator, ctx, scalar| call_float(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, "float", &[n_ty]),
@@ -440,18 +435,20 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ret_elem_ty,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty.into() },
+                    |generator, ctx, scalar| {
+                        call_round(generator, ctx, (elem_ty, scalar), ret_elem_ty)
+                    },
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[n_ty]),
@@ -477,18 +474,18 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.float,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.f64_type().into() },
+                    |generator, ctx, scalar| call_numpy_round(generator, ctx, (elem_ty, scalar)),
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[n_ty]),
@@ -539,22 +536,21 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ctx.primitives.bool,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| {
-                    let elem = call_bool(generator, ctx, (elem_ty, val))?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: ctx.ctx.i8_type().into() },
+                    |generator, ctx, scalar| {
+                        let elem = call_bool(generator, ctx, (elem_ty, scalar))?;
+                        Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into())
+                    },
+                )
+                .unwrap();
 
-                    Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into())
-                },
-            )?;
-
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[n_ty]),
@@ -591,18 +587,20 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ret_elem_ty,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty },
+                    |generator, ctx, scalar| {
+                        call_floor(generator, ctx, (elem_ty, scalar), ret_elem_ty)
+                    },
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[n_ty]),
@@ -639,18 +637,20 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
             if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, n_ty);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None);
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ret_elem_ty,
-                None,
-                llvm_ndarray_ty.map_value(n, None),
-                |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
-            )?;
+            let result = ndarray
+                .map(
+                    generator,
+                    ctx,
+                    NDArrayOut::NewNDArray { dtype: llvm_ret_elem_ty },
+                    |generator, ctx, scalar| {
+                        call_ceil(generator, ctx, (elem_ty, scalar), ret_elem_ty)
+                    },
+                )
+                .unwrap();
 
-            ndarray.as_base_value().into()
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[n_ty]),
@@ -741,42 +741,37 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
                 ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
             }) =>
         {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            let x1 =
+                ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx);
+            let x2 =
+                ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx);
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
+            let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
+            let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty);
 
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
+            debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype));
+            let llvm_common_dtype = x1.get_type().element_type();
 
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
+            let result = NDArrayType::new_broadcast(
+                generator,
+                ctx.ctx,
+                llvm_common_dtype,
+                &[x1.get_type(), x2.get_type()],
+            )
+            .broadcast_starmap(
                 generator,
                 ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
+                &[x1, x2],
+                NDArrayOut::NewNDArray { dtype: llvm_common_dtype },
+                |_, ctx, scalars| {
+                    let x1_scalar = scalars[0];
+                    let x2_scalar = scalars[1];
+                    Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar)))
                 },
-            )?
-            .as_base_value()
-            .into()
+            )
+            .unwrap();
+
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
@@ -861,23 +856,26 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
                 _ => codegen_unreachable!(ctx),
             }
         }
+
         BasicValueEnum::PointerValue(n)
             if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
         {
             let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty);
 
-            let n = llvm_ndarray_ty.map_value(n, None);
-            let n_sz = n.size(generator, ctx);
+            let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None);
+            let llvm_dtype = ndarray.get_type().element_type();
+
+            let zero = llvm_usize.const_zero();
+
             if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
-                let n_sz_eqz = ctx
+                let size_nez = ctx
                     .builder
-                    .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
+                    .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "")
                     .unwrap();
 
                 ctx.make_assert(
                     generator,
-                    n_sz_eqz,
+                    size_nez,
                     "0:ValueError",
                     format!("zero-size array to reduction operation {fn_name}").as_str(),
                     [None, None, None],
@@ -885,54 +883,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
                 );
             }
 
-            let accumulator_addr =
-                generator.gen_var_alloc(ctx, llvm_ndarray_ty.element_type(), None)?;
-            let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
+            let extremum = generator.gen_var_alloc(ctx, llvm_dtype, None)?;
+            let extremum_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
 
-            unsafe {
-                let identity =
-                    n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
-                ctx.builder.build_store(accumulator_addr, identity).unwrap();
-                ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap();
-            }
+            let first_value = unsafe { ndarray.data().get_unchecked(ctx, generator, &zero, None) };
+            ctx.builder.build_store(extremum, first_value).unwrap();
+            ctx.builder.build_store(extremum_idx, zero).unwrap();
 
-            gen_for_callback_incrementing(
-                generator,
-                ctx,
-                None,
-                llvm_int64.const_int(1, false),
-                (n_sz, false),
-                |generator, ctx, _, idx| {
-                    let elem = unsafe {
-                        n.data().get_unchecked(
-                            ctx,
-                            generator,
-                            &ctx.builder
-                                .build_int_truncate_or_bit_cast(idx, llvm_usize, "")
-                                .unwrap(),
-                            None,
-                        )
-                    };
-                    let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
-                    let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
+            // The first element is iterated, but this doesn't matter.
+            ndarray
+                .foreach(generator, ctx, |_, ctx, _, nditer| {
+                    let old_extremum = ctx.builder.build_load(extremum, "").unwrap();
+                    let old_extremum_idx = ctx
+                        .builder
+                        .build_load(extremum_idx, "")
+                        .map(BasicValueEnum::into_int_value)
+                        .unwrap();
 
-                    let result = match fn_name {
+                    let curr_value = nditer.get_scalar(ctx);
+                    let curr_idx = nditer.get_index(ctx);
+
+                    let new_extremum = match fn_name {
                         "np_argmin" | "np_min" => {
-                            call_min(ctx, (elem_ty, accumulator), (elem_ty, elem))
+                            call_min(ctx, (elem_ty, old_extremum), (elem_ty, curr_value))
                         }
                         "np_argmax" | "np_max" => {
-                            call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
+                            call_max(ctx, (elem_ty, old_extremum), (elem_ty, curr_value))
                         }
                         _ => codegen_unreachable!(ctx),
                     };
 
-                    let updated_idx = match (accumulator, result) {
+                    let new_extremum_idx = match (old_extremum, new_extremum) {
                         (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx
                             .builder
                             .build_select(
                                 ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(),
-                                idx.into(),
-                                cur_idx,
+                                curr_idx,
+                                old_extremum_idx,
                                 "",
                             )
                             .unwrap(),
@@ -942,24 +929,35 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
                                 ctx.builder
                                     .build_float_compare(FloatPredicate::ONE, m, n, "")
                                     .unwrap(),
-                                idx.into(),
-                                cur_idx,
+                                curr_idx,
+                                old_extremum_idx,
                                 "",
                             )
                             .unwrap(),
                         _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
                     };
-                    ctx.builder.build_store(res_idx, updated_idx).unwrap();
-                    ctx.builder.build_store(accumulator_addr, result).unwrap();
+
+                    ctx.builder.build_store(extremum, new_extremum).unwrap();
+                    ctx.builder.build_store(extremum_idx, new_extremum_idx).unwrap();
 
                     Ok(())
-                },
-                llvm_int64.const_int(1, false),
-            )?;
+                })
+                .unwrap();
 
             match fn_name {
-                "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
-                "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
+                "np_argmin" | "np_argmax" => ctx
+                    .builder
+                    .build_int_s_extend_or_bit_cast(
+                        ctx.builder
+                            .build_load(extremum_idx, "")
+                            .map(BasicValueEnum::into_int_value)
+                            .unwrap(),
+                        ctx.ctx.i64_type(),
+                        "",
+                    )
+                    .unwrap()
+                    .into(),
+                "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(),
                 _ => codegen_unreachable!(ctx),
             }
         }
@@ -1006,42 +1004,37 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
                 ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
             }) =>
         {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            let x1 =
+                ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1)).to_ndarray(generator, ctx);
+            let x2 =
+                ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2)).to_ndarray(generator, ctx);
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
+            let x1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
+            let x2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty);
 
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
+            debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype));
+            let llvm_common_dtype = x1.get_type().element_type();
 
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
+            let result = NDArrayType::new_broadcast(
+                generator,
+                ctx.ctx,
+                llvm_common_dtype,
+                &[x1.get_type(), x2.get_type()],
+            )
+            .broadcast_starmap(
                 generator,
                 ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
+                &[x1, x2],
+                NDArrayOut::NewNDArray { dtype: llvm_common_dtype },
+                |_, ctx, scalars| {
+                    let x1_scalar = scalars[0];
+                    let x2_scalar = scalars[1];
+                    Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar)))
                 },
-            )?
-            .as_base_value()
-            .into()
+            )
+            .unwrap();
+
+            result.as_base_value().into()
         }
 
         _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
@@ -1074,39 +1067,20 @@ where
     ) -> Option<BasicValueEnum<'ctx>>,
     RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type,
 {
-    let result = match arg_val {
-        BasicValueEnum::PointerValue(x)
-            if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
-        {
-            let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
-            let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, arg_ty);
-            let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
+    let arg = ScalarOrNDArray::from_value(generator, ctx, (arg_ty, arg_val));
 
-            let ndarray = ndarray_elementwise_unaryop_impl(
-                generator,
-                ctx,
-                ret_elem_ty,
-                None,
-                llvm_ndarray_ty.map_value(x, None),
-                |generator, ctx, elem_val| {
-                    helper_call_numpy_unary_elementwise(
-                        generator,
-                        ctx,
-                        (arg_elem_ty, elem_val),
-                        fn_name,
-                        get_ret_elem_type,
-                        on_scalar,
-                    )
-                },
-            )?;
-            ndarray.as_base_value().into()
-        }
+    let dtype = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty);
 
-        _ => on_scalar(generator, ctx, arg_ty, arg_val)
-            .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])),
-    };
+    let ret_ty = get_ret_elem_type(ctx, dtype);
+    let llvm_ret_ty = ctx.get_llvm_type(generator, ret_ty);
+    let result = arg.map(generator, ctx, llvm_ret_ty, |generator, ctx, scalar| {
+        let Some(result) = on_scalar(generator, ctx, dtype, scalar) else {
+            unsupported_type(ctx, fn_name, &[arg_ty])
+        };
+        Ok(result)
+    })?;
 
-    Ok(result)
+    Ok(result.to_basic_value_enum())
 }
 
 pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
@@ -1431,59 +1405,29 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_arctan2";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            extern_fns::call_atan2(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(extern_fns::call_atan2(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_copysign` builtin function.
@@ -1495,59 +1439,29 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_copysign";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_fmax` builtin function.
@@ -1559,59 +1473,29 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_fmax";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_fmin` builtin function.
@@ -1623,59 +1507,29 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_fmin";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_ldexp` builtin function.
@@ -1687,48 +1541,31 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_ldexp";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            extern_fns::call_ldexp(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => {
+                    debug_assert_eq!(x1.get_dtype(), ctx.ctx.f64_type().into());
+                    debug_assert_eq!(x2.get_dtype(), ctx.ctx.i32_type().into());
+                    Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype =
-                if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
-
-            let x1_scalar_ty = dtype;
-            let x2_scalar_ty =
-                if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_hypot` builtin function.
@@ -1740,59 +1577,29 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_hypot";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            extern_fns::call_hypot(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(extern_fns::call_hypot(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_nextafter` builtin function.
@@ -1804,59 +1611,29 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
 ) -> Result<BasicValueEnum<'ctx>, String> {
     const FN_NAME: &str = "np_nextafter";
 
-    Ok(match (x1, x2) {
-        (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
-            debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
-            debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
+    let x1 = ScalarOrNDArray::from_value(generator, ctx, (x1_ty, x1));
+    let x2 = ScalarOrNDArray::from_value(generator, ctx, (x2_ty, x2));
 
-            extern_fns::call_nextafter(ctx, x1, x2, None).into()
-        }
+    let result = ScalarOrNDArray::broadcasting_starmap(
+        generator,
+        ctx,
+        &[x1, x2],
+        ctx.ctx.f64_type().into(),
+        |_, ctx, scalars| {
+            let x1_scalar = scalars[0];
+            let x2_scalar = scalars[1];
 
-        (x1, x2)
-            if [&x1_ty, &x2_ty].into_iter().any(|ty| {
-                ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
-            }) =>
-        {
-            let is_ndarray1 =
-                x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
-            let is_ndarray2 =
-                x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
+            match (x1_scalar, x2_scalar) {
+                (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
+                    Ok(extern_fns::call_nextafter(ctx, x1, x2, None).into())
+                }
+                _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
+            }
+        },
+    )
+    .unwrap();
 
-            let dtype = if is_ndarray1 && is_ndarray2 {
-                let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
-                let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
-
-                debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
-
-                ndarray_dtype1
-            } else if is_ndarray1 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
-            } else if is_ndarray2 {
-                unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
-            } else {
-                codegen_unreachable!(ctx)
-            };
-
-            let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
-            let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty };
-
-            numpy::ndarray_elementwise_binop_impl(
-                generator,
-                ctx,
-                dtype,
-                None,
-                (x1_ty, x1, !is_ndarray1),
-                (x2_ty, x2, !is_ndarray2),
-                |generator, ctx, (lhs, rhs)| {
-                    call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
-                },
-            )?
-            .as_base_value()
-            .into()
-        }
-
-        _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
-    })
+    Ok(result.to_basic_value_enum())
 }
 
 /// Invokes the `np_linalg_cholesky` linalg function