use crate::typecheck::typedef::Type;

use super::{
    classes::{
        ArrayLikeIndexer,
        ArrayLikeValue,
        ArraySliceValue,
        ListValue,
        NDArrayValue,
        TypedArrayLikeAdapter,
        UntypedArrayLikeAccessor,
    },
    CodeGenContext,
    CodeGenerator,
    llvm_intrinsics,
};
use inkwell::{
    attributes::{Attribute, AttributeLoc},
    context::Context,
    memory_buffer::MemoryBuffer,
    module::Module,
    types::{BasicTypeEnum, IntType},
    values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
    AddressSpace, IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;

/// Parses the IRRT bitcode embedded into the binary at build time (`$OUT_DIR/irrt.bc`) and
/// returns it as a [`Module`] owned by `ctx`.
///
/// A fixed set of small runtime helpers is additionally tagged with the `alwaysinline` function
/// attribute.
///
/// # Panics
///
/// Panics if the embedded bitcode cannot be parsed, or if any of the helpers that should be
/// tagged is missing from the parsed module.
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
    // Runtime helpers which receive the `alwaysinline` attribute.
    const ALWAYS_INLINE_SYMBOLS: [&str; 4] = [
        "__nac3_int_exp_int32_t",
        "__nac3_int_exp_int64_t",
        "__nac3_range_slice_len",
        "__nac3_slice_index_bound",
    ];

    let buffer = MemoryBuffer::create_from_memory_range(
        include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
        "irrt_bitcode_buffer",
    );
    let module = Module::parse_bitcode_from_buffer(&buffer, ctx).unwrap();

    let always_inline_kind = Attribute::get_named_enum_kind_id("alwaysinline");
    for name in ALWAYS_INLINE_SYMBOLS {
        let attr = ctx.create_enum_attribute(always_inline_kind, 0);
        module.get_function(name).unwrap().add_attribute(AttributeLoc::Function, attr);
    }

    module
}

/// Generates a call to the IRRT integer exponentiation routine and returns `base ** exp`.
///
/// The runtime routine uses the repeated squaring method adapted from the GNU Scientific Library:
/// <https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c>
///
/// * `base` / `exp` - Operands; both must be 32-bit or both must be 64-bit integers.
/// * `signed` - Selects between the `int*_t` and `uint*_t` variants of the runtime routine.
///
/// Before the call, an assertion (`0:ValueError`) is emitted that requires `exp >= 0` under a
/// signed comparison.
///
/// # Panics
///
/// Panics if the operand bit widths are not `(32, 32)` or `(64, 64)`.
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    base: IntValue<'ctx>,
    exp: IntValue<'ctx>,
    signed: bool,
) -> IntValue<'ctx> {
    let llvm_base_t = base.get_type();
    let llvm_exp_t = exp.get_type();

    let fn_name = match (llvm_base_t.get_bit_width(), llvm_exp_t.get_bit_width(), signed) {
        (32, 32, true) => "__nac3_int_exp_int32_t",
        (64, 64, true) => "__nac3_int_exp_int64_t",
        (32, 32, false) => "__nac3_int_exp_uint32_t",
        (64, 64, false) => "__nac3_int_exp_uint64_t",
        _ => unreachable!(),
    };

    // Reuse an existing declaration when there is one; otherwise declare `T fn(T, T)`.
    let pow_fn = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            let fn_t = llvm_base_t.fn_type(&[llvm_base_t.into(), llvm_base_t.into()], false);
            ctx.module.add_function(fn_name, fn_t, None)
        }
    };

    // A negative exponent is rejected before the runtime routine is called.
    let exp_is_non_negative = ctx
        .builder
        .build_int_compare(IntPredicate::SGE, exp, llvm_exp_t.const_zero(), "assert_int_pow_ge_0")
        .unwrap();
    ctx.make_assert(
        generator,
        exp_is_non_negative,
        "0:ValueError",
        "integer power must be positive or zero",
        [None, None, None],
        ctx.current_loc,
    );

    let call = ctx.builder.build_call(pow_fn, &[base.into(), exp.into()], "call_int_pow").unwrap();
    call.try_as_basic_value().unwrap_left().into_int_value()
}

/// Generates a call to `__nac3_range_slice_len` and returns its `i32` result for the given
/// `start`, `end` and `step`.
///
/// Before the call, an assertion (`0:ValueError`) is emitted that requires `step != 0`.
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    start: IntValue<'ctx>,
    end: IntValue<'ctx>,
    step: IntValue<'ctx>,
) -> IntValue<'ctx> {
    const SYMBOL: &str = "__nac3_range_slice_len";

    // Declared as `i32 fn(i32, i32, i32)` when not yet present in the module.
    let range_len_fn = match ctx.module.get_function(SYMBOL) {
        Some(existing) => existing,
        None => {
            let llvm_i32 = ctx.ctx.i32_type();
            let fn_t =
                llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false);
            ctx.module.add_function(SYMBOL, fn_t, None)
        }
    };

    // A zero step is rejected before the runtime routine is called.
    let step_zero = step.get_type().const_zero();
    let step_is_nonzero = ctx
        .builder
        .build_int_compare(IntPredicate::NE, step, step_zero, "range_step_ne")
        .unwrap();
    ctx.make_assert(
        generator,
        step_is_nonzero,
        "0:ValueError",
        "step must not be zero",
        [None, None, None],
        ctx.current_loc,
    );

    let call = ctx
        .builder
        .build_call(range_len_fn, &[start.into(), end.into(), step.into()], "calc_len")
        .unwrap();
    call.try_as_basic_value().unwrap_left().into_int_value()
}

/// Generates IR that normalizes the optional `start`/`end`/`step` expressions of a slice over
/// `list` into a concrete `(start, end, step)` triple of `i32` values.
///
/// NOTE: the end index produced by this function must be compared ***inclusively***. Python
/// allows `a[2::-1]`, which means `[a[2], a[1], a[0]]`; there is no numeric exclusive end index
/// that expresses this slice, so an inclusive end is used instead.
///
/// Each index expression that is present is passed through [`handle_slice_index_bound`]
/// (`handle_in_bound` in the pseudo-code below). When a step expression is present, an assertion
/// (`0:ValueError`) is emitted that requires it to be non-zero.
///
/// Returns `Ok(None)` if generating any of the index expressions yields no value.
///
/// equivalent code:
/// ```pseudo_code
/// match (start, end, step):
///     case (s, e, None | Some(step)) if step > 0:
///         return (
///             match s:
///                 case None:
///                     0
///                 case Some(s):
///                     handle_in_bound(s)
///             ,match e:
///                 case None:
///                     length - 1
///                 case Some(e):
///                     handle_in_bound(e) - 1
///             ,step == None ? 1 : step
///         )
///     case (s, e, Some(step)) if step < 0:
///         return (
///             match s:
///                 case None:
///                     length - 1
///                 case Some(s):
///                     s = handle_in_bound(s)
///                     if s == length:
///                         s - 1
///                     else:
///                         s
///             ,match e:
///                 case None:
///                     0
///                 case Some(e):
///                     handle_in_bound(e) + 1
///             ,step
///         )
/// ```
pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
    start: &Option<Box<Expr<Option<Type>>>>,
    end: &Option<Box<Expr<Option<Type>>>>,
    step: &Option<Box<Expr<Option<Type>>>>,
    ctx: &mut CodeGenContext<'ctx, '_>,
    generator: &mut G,
    list: ListValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
    let int32 = ctx.ctx.i32_type();
    let zero = int32.const_zero();
    let one = int32.const_int(1, false);
    // All index arithmetic below is performed on `i32`, so the list length is converted first.
    let length = list.load_size(ctx, Some("length"));
    let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
    Ok(Some(match (start, end, step) {
        // No step expression: the step is the constant 1.
        (s, e, None) => (
            // start: bounded `s`, or 0 when omitted
            if let Some(s) = s.as_ref() {
                match handle_slice_index_bound(s, ctx, generator, length)? {
                    Some(v) => v,
                    None => return Ok(None),
                }
            } else {
                int32.const_zero()
            },
            // end (inclusive): (bounded `e`, or `length` when omitted) - 1
            {
                let e = if let Some(s) = e.as_ref() {
                    match handle_slice_index_bound(s, ctx, generator, length)? {
                        Some(v) => v,
                        None => return Ok(None),
                    }
                } else {
                    length
                };
                ctx.builder.build_int_sub(e, one, "final_end").unwrap()
            },
            one,
        ),
        // Step expression present: its sign is only known at runtime, so both the forward and
        // the backward candidates are computed and one is picked with `select`.
        (s, e, Some(step)) => {
            let step = if let Some(v) = generator.gen_expr(ctx, step)? {
                v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
            } else {
                return Ok(None)
            };
            // assert step != 0, throw exception if not
            let not_zero = ctx.builder.build_int_compare(
                IntPredicate::NE,
                step,
                step.get_type().const_zero(),
                "range_step_ne",
            ).unwrap();
            ctx.make_assert(
                generator,
                not_zero,
                "0:ValueError",
                "slice step cannot be zero",
                [None, None, None],
                ctx.current_loc,
            );
            // `len_id` is the index of the last element (`length - 1`).
            let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
            let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg").unwrap();
            (
                match s {
                    Some(s) => {
                        let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
                            return Ok(None)
                        };
                        // With a negative step, a bounded start equal to `length` is moved back
                        // by one so that it refers to the last element.
                        ctx.builder
                            .build_select(
                                ctx.builder.build_and(
                                    ctx.builder.build_int_compare(
                                        IntPredicate::EQ,
                                        s,
                                        length,
                                        "s_eq_len",
                                    ).unwrap(),
                                    neg,
                                    "should_minus_one",
                                ).unwrap(),
                                ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
                                s,
                                "final_start",
                            )
                            .map(BasicValueEnum::into_int_value)
                            .unwrap()
                    }
                    // Omitted start: last element for a negative step, first element otherwise.
                    None => ctx.builder.build_select(neg, len_id, zero, "stt")
                        .map(BasicValueEnum::into_int_value)
                        .unwrap(),
                },
                match e {
                    Some(e) => {
                        let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
                            return Ok(None)
                        };
                        // Convert the exclusive end into an inclusive one: one step "inwards",
                        // i.e. `e + 1` for a negative step and `e - 1` otherwise.
                        ctx.builder
                            .build_select(
                                neg,
                                ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
                                ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
                                "final_end",
                            )
                            .map(BasicValueEnum::into_int_value)
                            .unwrap()
                    }
                    // Omitted end: first element for a negative step, last element otherwise.
                    None => ctx.builder.build_select(neg, zero, len_id, "end")
                        .map(BasicValueEnum::into_int_value)
                        .unwrap(),
                },
                step,
            )
        }
    }))
}

/// Generates the value of the slice index expression `i` and passes it, together with `length`,
/// to the `__nac3_slice_index_bound` runtime function, returning the `i32` result of that call.
///
/// An out-of-range index is deliberately not treated as an error here, since Python permits it
/// in slices (`a = [1,2,3]; a[1:10] == [2,3]`).
///
/// Returns `Ok(None)` if generating `i` yields no value.
///
/// # Panics
///
/// Panics if `i` carries no type annotation (`i.custom` is `None`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
    i: &Expr<Option<Type>>,
    ctx: &mut CodeGenContext<'ctx, '_>,
    generator: &mut G,
    length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
    const SYMBOL: &str = "__nac3_slice_index_bound";

    // Declared as `i32 fn(i32, i32)` when not yet present in the module.
    let bound_fn = match ctx.module.get_function(SYMBOL) {
        Some(existing) => existing,
        None => {
            let llvm_i32 = ctx.ctx.i32_type();
            let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into()], false);
            ctx.module.add_function(SYMBOL, fn_t, None)
        }
    };

    let Some(index_val) = generator.gen_expr(ctx, i)? else {
        return Ok(None);
    };
    let index_val = index_val.to_basic_value_enum(ctx, generator, i.custom.unwrap())?;

    let call = ctx
        .builder
        .build_call(bound_fn, &[index_val.into(), length.into()], "bounded_ind")
        .unwrap();
    let bounded = call.try_as_basic_value().unwrap_left().into_int_value();

    Ok(Some(bounded))
}

/// Generates IR that assigns the slice `src_arr[src_idx]` to the slice `dest_arr[dest_idx]` by
/// calling the `__nac3_list_slice_assign_var_size` runtime function.
///
/// * `ty` - LLVM type of a list element; its size is passed to the runtime function. Only
///   float, int, pointer and struct types are supported.
/// * `dest_idx` / `src_idx` - Slice indices in the order `(start, end, step)`. `end` is handled
///   **inclusively**. Negative indices must already have been resolved by the caller.
///
/// Before the call, an assertion (`0:ValueError`) is emitted that requires the source slice
/// length to equal the destination slice length, or to be smaller than it when the destination
/// step is 1.
///
/// The value returned by the runtime function is compared with the current length of
/// `dest_arr`; if they differ, it is stored as the new size of `dest_arr`.
///
/// # Panics
///
/// Panics if `ty` is not a float, int, pointer or struct type, or if a struct type has no known
/// size.
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    ty: BasicTypeEnum<'ctx>,
    dest_arr: ListValue<'ctx>,
    dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
    src_arr: ListValue<'ctx>,
    src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
    let size_ty = generator.get_size_type(ctx.ctx);
    let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
    let int32 = ctx.ctx.i32_type();
    // Element data is passed as untyped byte pointers plus an explicit element size.
    let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
    let slice_assign_fun = {
        let ty_vec = vec![
            int32.into(),         // dest start idx
            int32.into(),         // dest end idx
            int32.into(),         // dest step
            elem_ptr_type.into(), // dest arr ptr
            int32.into(),         // dest arr len
            int32.into(),         // src start idx
            int32.into(),         // src end idx
            int32.into(),         // src step
            elem_ptr_type.into(), // src arr ptr
            int32.into(),         // src arr len
            int32.into(),         // size
        ];
        ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
            let fn_t = int32.fn_type(ty_vec.as_slice(), false);
            ctx.module.add_function(fun_symbol, fn_t, None)
        })
    };

    let zero = int32.const_zero();
    let one = int32.const_int(1, false);
    let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
    let dest_arr_ptr = ctx.builder.build_pointer_cast(
        dest_arr_ptr,
        elem_ptr_type,
        "dest_arr_ptr_cast",
    ).unwrap();
    let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
    // Labelled after the destination (previously mislabelled with the source's "srclen32").
    let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "destlen32").unwrap();
    let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
    let src_arr_ptr = ctx.builder.build_pointer_cast(
        src_arr_ptr,
        elem_ptr_type,
        "src_arr_ptr_cast",
    ).unwrap();
    let src_len = src_arr.load_size(ctx, Some("src.len"));
    let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();

    // index in bound and positive should be done
    // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
    // throw exception if not satisfied

    // The incoming end indices are inclusive, whereas `calculate_len_for_slice_range` is given
    // an end moved one element further in the direction of the step.
    let src_end = ctx.builder
        .build_select(
            ctx.builder.build_int_compare(
                IntPredicate::SLT,
                src_idx.2,
                zero,
                "is_neg",
            ).unwrap(),
            ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
            ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
            "final_e",
        )
        .map(BasicValueEnum::into_int_value)
        .unwrap();
    let dest_end = ctx.builder
        .build_select(
            ctx.builder.build_int_compare(
                IntPredicate::SLT,
                dest_idx.2,
                zero,
                "is_neg",
            ).unwrap(),
            ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
            ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
            "final_e",
        )
        .map(BasicValueEnum::into_int_value)
        .unwrap();
    let src_slice_len =
        calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
    let dest_slice_len =
        calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
    let src_eq_dest = ctx.builder.build_int_compare(
        IntPredicate::EQ,
        src_slice_len,
        dest_slice_len,
        "slice_src_eq_dest",
    ).unwrap();
    let src_slt_dest = ctx.builder.build_int_compare(
        IntPredicate::SLT,
        src_slice_len,
        dest_slice_len,
        "slice_src_slt_dest",
    ).unwrap();
    let dest_step_eq_one = ctx.builder.build_int_compare(
        IntPredicate::EQ,
        dest_idx.2,
        dest_idx.2.get_type().const_int(1, false),
        "slice_dest_step_eq_one",
    ).unwrap();
    // cond = (src_len == dest_len) || (dest_step == 1 && src_len < dest_len)
    let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
    let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
    ctx.make_assert(
        generator,
        cond,
        "0:ValueError",
        "attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
        [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
        ctx.current_loc,
    );

    let new_len = {
        let args = vec![
            dest_idx.0.into(),   // dest start idx
            dest_idx.1.into(),   // dest end idx
            dest_idx.2.into(),   // dest step
            dest_arr_ptr.into(), // dest arr ptr
            dest_len.into(),     // dest arr len
            src_idx.0.into(),    // src start idx
            src_idx.1.into(),    // src end idx
            src_idx.2.into(),    // src step
            src_arr_ptr.into(),  // src arr ptr
            src_len.into(),      // src arr len
            {
                // size of a single element, narrowed to `i32`
                let s = match ty {
                    BasicTypeEnum::FloatType(t) => t.size_of(),
                    BasicTypeEnum::IntType(t) => t.size_of(),
                    BasicTypeEnum::PointerType(t) => t.size_of(),
                    BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
                    _ => unreachable!(),
                };
                ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
            }
            .into(),
        ];
        ctx.builder
            .build_call(slice_assign_fun, args.as_slice(), "slice_assign")
            .map(CallSiteValue::try_as_basic_value)
            .map(|v| v.map_left(BasicValueEnum::into_int_value))
            .map(Either::unwrap_left)
            .unwrap()
    };
    // update length: only store the size when the runtime function reports a different length
    let need_update = ctx.builder
        .build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")
        .unwrap();
    let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
    let update_bb = ctx.ctx.append_basic_block(current, "update");
    let cont_bb = ctx.ctx.append_basic_block(current, "cont");
    ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
    ctx.builder.position_at_end(update_bb);
    let new_len = ctx.builder
        .build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len")
        .unwrap();
    dest_arr.store_size(ctx, generator, new_len);
    ctx.builder.build_unconditional_branch(cont_bb).unwrap();
    // Subsequent code generation continues in the `cont` block.
    ctx.builder.position_at_end(cont_bb);
}

/// Emits a call to the `__nac3_isinf` runtime function for `v`.
///
/// The runtime function is declared as `i32 fn(f64)`; its result is converted with
/// [`CodeGenerator::bool_to_i1`] and returned as an `i1`.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
    const FN_NAME: &str = "__nac3_isinf";

    let isinf_fn = match ctx.module.get_function(FN_NAME) {
        Some(existing) => existing,
        None => {
            let fn_t = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
            ctx.module.add_function(FN_NAME, fn_t, None)
        }
    };

    let call = ctx.builder.build_call(isinf_fn, &[v.into()], "isinf").unwrap();
    let result = call.try_as_basic_value().unwrap_left().into_int_value();

    generator.bool_to_i1(ctx, result)
}

/// Emits a call to the `__nac3_isnan` runtime function for `v`.
///
/// The runtime function is declared as `i32 fn(f64)`; its result is converted with
/// [`CodeGenerator::bool_to_i1`] and returned as an `i1`.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
    const FN_NAME: &str = "__nac3_isnan";

    let isnan_fn = match ctx.module.get_function(FN_NAME) {
        Some(existing) => existing,
        None => {
            let fn_t = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
            ctx.module.add_function(FN_NAME, fn_t, None)
        }
    };

    let call = ctx.builder.build_call(isnan_fn, &[v.into()], "isnan").unwrap();
    let result = call.try_as_basic_value().unwrap_left().into_int_value();

    generator.bool_to_i1(ctx, result)
}

/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
    let llvm_f64 = ctx.ctx.f64_type();

    let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
        let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
        ctx.module.add_function("__nac3_gamma", fn_type, None)
    });

    ctx.builder
        .build_call(intrinsic_fn, &[v.into()], "gamma")
        .map(CallSiteValue::try_as_basic_value)
        .map(|v| v.map_left(BasicValueEnum::into_float_value))
        .map(Either::unwrap_left)
        .unwrap()
}

/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
    let llvm_f64 = ctx.ctx.f64_type();

    let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
        let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
        ctx.module.add_function("__nac3_gammaln", fn_type, None)
    });

    ctx.builder
        .build_call(intrinsic_fn, &[v.into()], "gammaln")
        .map(CallSiteValue::try_as_basic_value)
        .map(|v| v.map_left(BasicValueEnum::into_float_value))
        .map(Either::unwrap_left)
        .unwrap()
}

/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
    let llvm_f64 = ctx.ctx.f64_type();

    let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
        let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
        ctx.module.add_function("__nac3_j0", fn_type, None)
    });

    ctx.builder
        .build_call(intrinsic_fn, &[v.into()], "j0")
        .map(CallSiteValue::try_as_basic_value)
        .map(|v| v.map_left(BasicValueEnum::into_float_value))
        .map(Either::unwrap_left)
        .unwrap()
}

/// Emits a call to `__nac3_ndarray_calc_size` (or its `64` variant, depending on the bit width of
/// the size type) and returns the [`IntValue`] produced by the call.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `(begin, end)` - The dimension indices at which the calculation begins and (exclusively)
///   ends. A `begin` of [`None`] defaults to zero; an `end` of [`None`] defaults to the size of
///   `dims`.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    dims: &Dims,
    (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
    G: CodeGenerator + ?Sized,
    Dims: ArrayLikeIndexer<'ctx>,
{
    let llvm_i64 = ctx.ctx.i64_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    // NOTE(review): the dimension pointer is declared as `i64*` here, while the other ndarray
    // helpers in this file declare dimension pointers as `usize*` - confirm against the IRRT
    // signature for targets with a 32-bit size type.
    let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_size",
        64 => "__nac3_ndarray_calc_size64",
        bw => unreachable!("Unsupported size type bit width: {}", bw)
    };
    let fn_t = llvm_usize.fn_type(
        &[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
        false,
    );
    let calc_size_fn = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => ctx.module.add_function(fn_name, fn_t, None),
    };

    // Resolve the defaults of the optional range first.
    let begin = match begin {
        Some(b) => b,
        None => llvm_usize.const_zero(),
    };
    let end = match end {
        Some(e) => e,
        None => dims.size(ctx, generator),
    };

    let dims_ptr = dims.base_ptr(ctx, generator);
    let dims_len = dims.size(ctx, generator);

    let call = ctx
        .builder
        .build_call(
            calc_size_fn,
            &[dims_ptr.into(), dims_len.into(), begin.into(), end.into()],
            "",
        )
        .unwrap();
    call.try_as_basic_value().unwrap_left().into_int_value()
}

/// Emits a call to `__nac3_ndarray_calc_nd_indices` (or its `64` variant, depending on the bit
/// width of the size type). Returns a [`TypedArrayLikeAdapter`] over a freshly `alloca`-ed array
/// of `i32` with one entry per dimension of `ndarray`, which is passed to the runtime function as
/// its output buffer.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
///   `NDArray`.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    index: IntValue<'ctx>,
    ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
    let llvm_void = ctx.ctx.void_type();
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_nd_indices",
        64 => "__nac3_ndarray_calc_nd_indices64",
        bw => unreachable!("Unsupported size type bit width: {}", bw)
    };
    // Declared as `void fn(usize, usize*, usize, i32*)` when not yet present in the module.
    let calc_nd_indices_fn = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            let fn_t = llvm_void.fn_type(
                &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
                false,
            );
            ctx.module.add_function(fn_name, fn_t, None)
        }
    };

    let num_dims = ndarray.load_ndims(ctx);
    let dim_sizes = ndarray.dim_sizes();

    // Output buffer: one `i32` per dimension.
    let out_indices = ctx.builder.build_array_alloca(llvm_i32, num_dims, "").unwrap();

    let dim_sizes_ptr = dim_sizes.base_ptr(ctx, generator);
    ctx.builder
        .build_call(
            calc_nd_indices_fn,
            &[index.into(), dim_sizes_ptr.into(), num_dims.into(), out_indices.into()],
            "",
        )
        .unwrap();

    let out_slice = ArraySliceValue::from_ptr_val(out_indices, num_dims, None);
    TypedArrayLikeAdapter::from(
        out_slice,
        Box::new(|_, value| value.into_int_value()),
        Box::new(|_, value| value.into()),
    )
}

/// Emits a call to `__nac3_ndarray_flatten_index` (or its `64` variant, depending on the bit
/// width of the size type) and returns the [`IntValue`] produced by the call.
///
/// * `ndarray` - The `NDArray` whose dimension sizes and number of dimensions are passed to the
///   runtime function.
/// * `indices` - The multidimensional index; in debug builds its element type is checked to be
///   32 bits wide and its size to be as wide as the size type.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    indices: &Indices,
) -> IntValue<'ctx>
where
    G: CodeGenerator + ?Sized,
    Indices: ArrayLikeIndexer<'ctx>,
{
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());

    debug_assert_eq!(
        IntType::try_from(indices.element_type(ctx, generator))
            .map(IntType::get_bit_width)
            .unwrap_or_default(),
        llvm_i32.get_bit_width(),
        "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
    );
    debug_assert_eq!(
        indices.size(ctx, generator).get_type().get_bit_width(),
        llvm_usize.get_bit_width(),
        "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
    );

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_flatten_index",
        64 => "__nac3_ndarray_flatten_index64",
        bw => unreachable!("Unsupported size type bit width: {}", bw)
    };
    // Declared as `usize fn(usize*, usize, i32*, usize)` when not yet present in the module.
    let flatten_index_fn = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            let fn_t = llvm_usize.fn_type(
                &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
                false,
            );
            ctx.module.add_function(fn_name, fn_t, None)
        }
    };

    let num_dims = ndarray.load_ndims(ctx);
    let dim_sizes = ndarray.dim_sizes();

    let dim_sizes_ptr = dim_sizes.base_ptr(ctx, generator);
    let indices_ptr = indices.base_ptr(ctx, generator);
    let indices_len = indices.size(ctx, generator);

    let call = ctx
        .builder
        .build_call(
            flatten_index_fn,
            &[dim_sizes_ptr.into(), num_dims.into(), indices_ptr.into(), indices_len.into()],
            "",
        )
        .unwrap();
    call.try_as_basic_value().unwrap_left().into_int_value()
}

/// Emits a call to `__nac3_ndarray_flatten_index` and returns the flattened index corresponding
/// to a multidimensional index.
///
/// This is the public entry point; it forwards all arguments unchanged to
/// `call_ndarray_flatten_index_impl`.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
///   `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    indices: &Index,
) -> IntValue<'ctx>
where
    G: CodeGenerator + ?Sized,
    Index: ArrayLikeIndexer<'ctx>,
{
    call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}

/// Generates a call to `__nac3_ndarray_calc_broadcast` (or its `64` variant, depending on the bit
/// width of the size type). Returns a [`TypedArrayLikeAdapter`] over a freshly `alloca`-ed array
/// of `max(lhs.ndims, rhs.ndims)` size-typed entries, which is passed to the runtime function as
/// the output buffer for the dimension sizes of the resultant `ndarray`.
///
/// Before the call, a loop is emitted which checks the trailing `min(lhs.ndims, rhs.ndims)`
/// dimensions of both operands pairwise, aligned at the last dimension: each pair must either be
/// equal or contain a dimension of size 1. Otherwise an assertion (`0:ValueError`) fails.
///
/// * `lhs` / `rhs` - The operands to broadcast against each other.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    lhs: NDArrayValue<'ctx>,
    rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
    let llvm_usize_const_one = llvm_usize.const_int(1, false);

    let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_broadcast",
        64 => "__nac3_ndarray_calc_broadcast64",
        bw => unreachable!("Unsupported size type bit width: {}", bw)
    };
    let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
        let fn_type = llvm_usize.fn_type(
            &[
                llvm_pusize.into(),
                llvm_usize.into(),
                llvm_pusize.into(),
                llvm_usize.into(),
                llvm_pusize.into(),
            ],
            false,
        );

        ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
    });

    let lhs_ndims = lhs.load_ndims(ctx);
    let rhs_ndims = rhs.load_ndims(ctx);
    let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);

    // Check that the trailing `min_ndims` dimensions of both operands are compatible.
    gen_for_callback_incrementing(
        generator,
        ctx,
        llvm_usize.const_zero(),
        (min_ndims, false),
        |generator, ctx, idx| {
            // Broadcasting aligns shapes at their *last* dimension, so the `idx`-th pair to
            // compare is `lhs[lhs_ndims - 1 - idx]` and `rhs[rhs_ndims - 1 - idx]`.
            let offset_from_end =
                ctx.builder.build_int_add(idx, llvm_usize_const_one, "").unwrap();
            let lhs_idx = ctx.builder.build_int_sub(lhs_ndims, offset_from_end, "").unwrap();
            let rhs_idx = ctx.builder.build_int_sub(rhs_ndims, offset_from_end, "").unwrap();

            // SAFETY: `idx` is in `[0, min_ndims)` and `min_ndims <= {lhs,rhs}_ndims`, so both
            // computed indices are in `[0, ndims)` of their respective operand.
            let (lhs_dim_sz, rhs_dim_sz) = unsafe {
                (
                    lhs.dim_sizes().get_typed_unchecked(ctx, generator, &lhs_idx, None),
                    rhs.dim_sizes().get_typed_unchecked(ctx, generator, &rhs_idx, None),
                )
            };

            let lhs_eq_one = ctx.builder.build_int_compare(
                IntPredicate::EQ,
                lhs_dim_sz,
                llvm_usize_const_one,
                "",
            ).unwrap();
            let rhs_eq_one = ctx.builder.build_int_compare(
                IntPredicate::EQ,
                rhs_dim_sz,
                llvm_usize_const_one,
                "",
            ).unwrap();
            let lhs_or_rhs_eq_one = ctx.builder.build_or(
                lhs_eq_one,
                rhs_eq_one,
                ""
            ).unwrap();

            let lhs_eq_rhs = ctx.builder.build_int_compare(
                IntPredicate::EQ,
                lhs_dim_sz,
                rhs_dim_sz,
                ""
            ).unwrap();

            // Compatible when either dimension is 1 or both dimensions are equal.
            let is_compatible = ctx.builder.build_or(
                lhs_or_rhs_eq_one,
                lhs_eq_rhs,
                ""
            ).unwrap();

            ctx.make_assert(
                generator,
                is_compatible,
                "0:ValueError",
                "operands could not be broadcast together",
                [None, None, None],
                ctx.current_loc,
            );

            Ok(())
        },
        llvm_usize_const_one,
    ).unwrap();

    let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
    let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
    let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
    // Output buffer: one size-typed entry per dimension of the broadcast result.
    let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
    let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);

    ctx.builder
        .build_call(
            ndarray_calc_broadcast_fn,
            &[
                lhs_dims.into(),
                lhs_ndims.into(),
                rhs_dims.into(),
                rhs_ndims.into(),
                out_dims.base_ptr(ctx, generator).into(),
            ],
            "",
        )
        .unwrap();

    TypedArrayLikeAdapter::from(
        out_dims,
        Box::new(|_, v| v.into_int_value()),
        Box::new(|_, v| v.into()),
    )
}

/// Emits a call to `__nac3_ndarray_calc_broadcast_idx` (or its `64` variant, depending on the bit
/// width of the size type). Returns a [`TypedArrayLikeAdapter`] over a freshly `alloca`-ed array
/// of `i32` with as many entries as `broadcast_idx`, which is passed to the runtime function as
/// the output buffer for the indices used to access `array`.
///
/// * `array` - The `NDArray` whose dimension sizes and number of dimensions are passed to the
///   runtime function.
/// * `broadcast_idx` - The index into the broadcasted array.
///
/// # Panics
///
/// Panics if the size type is neither 32 nor 64 bits wide.
pub fn call_ndarray_calc_broadcast_index<'ctx, G, BroadcastIdx>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    array: NDArrayValue<'ctx>,
    broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>
where
    G: CodeGenerator + ?Sized,
    BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
{
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_broadcast_idx",
        64 => "__nac3_ndarray_calc_broadcast_idx64",
        bw => unreachable!("Unsupported size type bit width: {}", bw)
    };
    // Declared as `usize fn(usize*, usize, i32*, i32*)` when not yet present in the module.
    let calc_broadcast_idx_fn = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            let fn_t = llvm_usize.fn_type(
                &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
                false,
            );
            ctx.module.add_function(fn_name, fn_t, None)
        }
    };

    // Output buffer: one `i32` per entry of `broadcast_idx`.
    let num_indices = broadcast_idx.size(ctx, generator);
    let out_indices = ctx.builder.build_array_alloca(llvm_i32, num_indices, "").unwrap();

    let dim_sizes_ptr = array.dim_sizes().base_ptr(ctx, generator);
    let num_dims = array.load_ndims(ctx);
    // SAFETY: only the address of element 0 is computed here.
    // NOTE(review): presumably `broadcast_idx` is never empty - confirm against callers.
    let broadcast_idx_ptr = unsafe {
        broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
    };

    ctx.builder
        .build_call(
            calc_broadcast_idx_fn,
            &[
                dim_sizes_ptr.into(),
                num_dims.into(),
                broadcast_idx_ptr.into(),
                out_indices.into(),
            ],
            "",
        )
        .unwrap();

    let out_slice = ArraySliceValue::from_ptr_val(out_indices, num_indices, None);
    TypedArrayLikeAdapter::from(
        out_slice,
        Box::new(|_, value| value.into_int_value()),
        Box::new(|_, value| value.into()),
    )
}