use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, memory_buffer::MemoryBuffer, module::Module, types::{BasicTypeEnum, IntType}, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; use nac3parser::ast::Expr; use super::{ llvm_intrinsics, macros::codegen_unreachable, stmt::gen_for_callback_incrementing, values::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { let bitcode_buf = MemoryBuffer::create_from_memory_range( include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), "irrt_bitcode_buffer", ); let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx).unwrap(); let inline_attr = Attribute::get_named_enum_kind_id("alwaysinline"); for symbol in &[ "__nac3_int_exp_int32_t", "__nac3_int_exp_int64_t", "__nac3_range_slice_len", "__nac3_slice_index_bound", ] { let function = irrt_mod.get_function(symbol).unwrap(); function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0)); } // Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`]. 
let exn_id_type = ctx.i32_type(); let errors = &[ ("EXN_INDEX_ERROR", "0:IndexError"), ("EXN_VALUE_ERROR", "0:ValueError"), ("EXN_ASSERTION_ERROR", "0:AssertionError"), ("EXN_TYPE_ERROR", "0:TypeError"), ]; for (irrt_name, symbol_name) in errors { let exn_id = symbol_resolver.get_string_id(symbol_name); let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum(); let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| { panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module") }); global.set_initializer(&exn_id); } irrt_mod } // repeated squaring method adapted from GNU Scientific Library: // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, base: IntValue<'ctx>, exp: IntValue<'ctx>, signed: bool, ) -> IntValue<'ctx> { let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { (32, 32, true) => "__nac3_int_exp_int32_t", (64, 64, true) => "__nac3_int_exp_int64_t", (32, 32, false) => "__nac3_int_exp_uint32_t", (64, 64, false) => "__nac3_int_exp_uint64_t", _ => codegen_unreachable!(ctx), }; let base_type = base.get_type(); let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); ctx.module.add_function(symbol, fn_type, None) }); // throw exception when exp < 0 let ge_zero = ctx .builder .build_int_compare( IntPredicate::SGE, exp, exp.get_type().const_zero(), "assert_int_pow_ge_0", ) .unwrap(); ctx.make_assert( generator, ge_zero, "0:ValueError", "integer power must be positive or zero", [None, None, None], ctx.current_loc, ); ctx.builder .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) .unwrap() } pub fn calculate_len_for_slice_range<'ctx, G: 
CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, start: IntValue<'ctx>, end: IntValue<'ctx>, step: IntValue<'ctx>, ) -> IntValue<'ctx> { const SYMBOL: &str = "__nac3_range_slice_len"; let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { let i32_t = ctx.ctx.i32_type(); let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); // assert step != 0, throw exception if not let not_zero = ctx .builder .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") .unwrap(); ctx.make_assert( generator, not_zero, "0:ValueError", "step must not be zero", [None, None, None], ctx.current_loc, ); ctx.builder .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) .unwrap() } /// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// NO numeric slice in python. /// /// equivalent code: /// ```pseudo_code /// match (start, end, step): /// case (s, e, None | Some(step)) if step > 0: /// return ( /// match s: /// case None: /// 0 /// case Some(s): /// handle_in_bound(s) /// ,match e: /// case None: /// length - 1 /// case Some(e): /// handle_in_bound(e) - 1 /// ,step == None ? 
1 : step /// ) /// case (s, e, Some(step)) if step < 0: /// return ( /// match s: /// case None: /// length - 1 /// case Some(s): /// s = handle_in_bound(s) /// if s == length: /// s - 1 /// else: /// s /// ,match e: /// case None: /// 0 /// case Some(e): /// handle_in_bound(e) + 1 /// ,step /// ) /// ``` pub fn handle_slice_indices<'ctx, G: CodeGenerator>( start: &Option>>>, end: &Option>>>, step: &Option>>>, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, length: IntValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let one = int32.const_int(1, false); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap(); Ok(Some(match (start, end, step) { (s, e, None) => ( if let Some(s) = s.as_ref() { match handle_slice_index_bound(s, ctx, generator, length)? { Some(v) => v, None => return Ok(None), } } else { int32.const_zero() }, { let e = if let Some(s) = e.as_ref() { match handle_slice_index_bound(s, ctx, generator, length)? { Some(v) => v, None => return Ok(None), } } else { length }; ctx.builder.build_int_sub(e, one, "final_end").unwrap() }, one, ), (s, e, Some(step)) => { let step = if let Some(v) = generator.gen_expr(ctx, step)? { v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value() } else { return Ok(None); }; // assert step != 0, throw exception if not let not_zero = ctx .builder .build_int_compare( IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne", ) .unwrap(); ctx.make_assert( generator, not_zero, "0:ValueError", "slice step cannot be zero", [None, None, None], ctx.current_loc, ); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap(); let neg = ctx .builder .build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg") .unwrap(); ( match s { Some(s) => { let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? 
else { return Ok(None); }; ctx.builder .build_select( ctx.builder .build_and( ctx.builder .build_int_compare( IntPredicate::EQ, s, length, "s_eq_len", ) .unwrap(), neg, "should_minus_one", ) .unwrap(), ctx.builder.build_int_sub(s, one, "s_min").unwrap(), s, "final_start", ) .map(BasicValueEnum::into_int_value) .unwrap() } None => ctx .builder .build_select(neg, len_id, zero, "stt") .map(BasicValueEnum::into_int_value) .unwrap(), }, match e { Some(e) => { let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else { return Ok(None); }; ctx.builder .build_select( neg, ctx.builder.build_int_add(e, one, "end_add_one").unwrap(), ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(), "final_end", ) .map(BasicValueEnum::into_int_value) .unwrap() } None => ctx .builder .build_select(neg, zero, len_id, "end") .map(BasicValueEnum::into_int_value) .unwrap(), }, step, ) } })) } /// this function allows index out of range, since python /// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( i: &Expr>, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, length: IntValue<'ctx>, ) -> Result>, String> { const SYMBOL: &str = "__nac3_slice_index_bound"; let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { let i32_t = ctx.ctx.i32_type(); let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); ctx.module.add_function(SYMBOL, fn_t, None) }); let i = if let Some(v) = generator.gen_expr(ctx, i)? { v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? } else { return Ok(None); }; Ok(Some( ctx.builder .build_call(func, &[i.into(), length.into()], "bounded_ind") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) .unwrap(), )) } /// This function handles 'end' **inclusively**. /// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). 
/// Negative index should be handled before entering this function pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ty: BasicTypeEnum<'ctx>, dest_arr: ListValue<'ctx>, dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { let size_ty = generator.get_size_type(ctx.ctx); let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let int32 = ctx.ctx.i32_type(); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); let slice_assign_fun = { let ty_vec = vec![ int32.into(), // dest start idx int32.into(), // dest end idx int32.into(), // dest step elem_ptr_type.into(), // dest arr ptr int32.into(), // dest arr len int32.into(), // src start idx int32.into(), // src end idx int32.into(), // src step elem_ptr_type.into(), // src arr ptr int32.into(), // src arr len int32.into(), // size ]; ctx.module.get_function(fun_symbol).unwrap_or_else(|| { let fn_t = int32.fn_type(ty_vec.as_slice(), false); ctx.module.add_function(fun_symbol, fn_t, None) }) }; let zero = int32.const_zero(); let one = int32.const_int(1, false); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); let dest_len = dest_arr.load_size(ctx, Some("dest.len")); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); let src_len = src_arr.load_size(ctx, Some("src.len")); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); // index in bound and positive should be done // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and 
// throw exception if not satisfied let src_end = ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), "final_e", ) .map(BasicValueEnum::into_int_value) .unwrap(); let dest_end = ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), "final_e", ) .map(BasicValueEnum::into_int_value) .unwrap(); let src_slice_len = calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); let dest_slice_len = calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); let src_eq_dest = ctx .builder .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") .unwrap(); let src_slt_dest = ctx .builder .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") .unwrap(); let dest_step_eq_one = ctx .builder .build_int_compare( IntPredicate::EQ, dest_idx.2, dest_idx.2.get_type().const_int(1, false), "slice_dest_step_eq_one", ) .unwrap(); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); ctx.make_assert( generator, cond, "0:ValueError", "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], ctx.current_loc, ); let new_len = { let args = vec![ dest_idx.0.into(), // dest start idx dest_idx.1.into(), // dest end idx dest_idx.2.into(), // dest step dest_arr_ptr.into(), // dest arr ptr dest_len.into(), // dest arr len src_idx.0.into(), // src start idx src_idx.1.into(), // src end idx src_idx.2.into(), // src 
step src_arr_ptr.into(), // src arr ptr src_len.into(), // src arr len { let s = match ty { BasicTypeEnum::FloatType(t) => t.size_of(), BasicTypeEnum::IntType(t) => t.size_of(), BasicTypeEnum::PointerType(t) => t.size_of(), BasicTypeEnum::StructType(t) => t.size_of().unwrap(), _ => codegen_unreachable!(ctx), }; ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() } .into(), ]; ctx.builder .build_call(slice_assign_fun, args.as_slice(), "slice_assign") .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) .unwrap() }; // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let update_bb = ctx.ctx.append_basic_block(current, "update"); let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.position_at_end(update_bb); let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); } /// Generates a call to `isinf` in IR. Returns an `i1` representing the result. 
///
/// * `v` - The floating-point value to test.
///
/// The call targets `__nac3_isinf`, which is declared as `i32 (f64)` in the current module if
/// it is not present yet. The `i32` result is converted via `generator.bool_to_i1`.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
    const SYMBOL: &str = "__nac3_isinf";

    let callee = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
        let llvm_i32 = ctx.ctx.i32_type();
        let llvm_f64 = ctx.ctx.f64_type();
        ctx.module.add_function(SYMBOL, llvm_i32.fn_type(&[llvm_f64.into()], false), None)
    });

    let call_site = ctx.builder.build_call(callee, &[v.into()], "isinf").unwrap();
    let raw_result = call_site.try_as_basic_value().map_left(BasicValueEnum::into_int_value);

    generator.bool_to_i1(ctx, raw_result.unwrap_left())
}

/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
///
/// * `v` - The floating-point value to test.
///
/// The call targets `__nac3_isnan`, which is declared as `i32 (f64)` in the current module if
/// it is not present yet. The `i32` result is converted via `generator.bool_to_i1`.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &CodeGenContext<'ctx, '_>,
    v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
    const SYMBOL: &str = "__nac3_isnan";

    let callee = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
        let llvm_i32 = ctx.ctx.i32_type();
        let llvm_f64 = ctx.ctx.f64_type();
        ctx.module.add_function(SYMBOL, llvm_i32.fn_type(&[llvm_f64.into()], false), None)
    });

    let call_site = ctx.builder.build_call(callee, &[v.into()], "isnan").unwrap();
    let raw_result = call_site.try_as_basic_value().map_left(BasicValueEnum::into_int_value);

    generator.bool_to_i1(ctx, raw_result.unwrap_left())
}

/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
///
/// * `v` - The argument passed to the function.
///
/// The call targets `__nac3_gamma`, which is declared as `f64 (f64)` in the current module if
/// it is not present yet.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
    const SYMBOL: &str = "__nac3_gamma";

    let callee = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
        let f64_t = ctx.ctx.f64_type();
        ctx.module.add_function(SYMBOL, f64_t.fn_type(&[f64_t.into()], false), None)
    });

    let call_site = ctx.builder.build_call(callee, &[v.into()], "gamma").unwrap();
    call_site.try_as_basic_value().map_left(BasicValueEnum::into_float_value).unwrap_left()
}

/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
///
/// * `v` - The argument passed to the function.
///
/// The call targets `__nac3_gammaln`, which is declared as `f64 (f64)` in the current module if
/// it is not present yet.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
    const SYMBOL: &str = "__nac3_gammaln";

    let callee = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
        let f64_t = ctx.ctx.f64_type();
        ctx.module.add_function(SYMBOL, f64_t.fn_type(&[f64_t.into()], false), None)
    });

    let call_site = ctx.builder.build_call(callee, &[v.into()], "gammaln").unwrap();
    call_site.try_as_basic_value().map_left(BasicValueEnum::into_float_value).unwrap_left()
}

/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
///
/// * `v` - The argument passed to the function.
///
/// The call targets `__nac3_j0`, which is declared as `f64 (f64)` in the current module if it
/// is not present yet.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
    const SYMBOL: &str = "__nac3_j0";

    let callee = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
        let f64_t = ctx.ctx.f64_type();
        ctx.module.add_function(SYMBOL, f64_t.fn_type(&[f64_t.into()], false), None)
    });

    let call_site = ctx.builder.build_call(callee, &[v.into()], "j0").unwrap();
    call_site.try_as_basic_value().map_left(BasicValueEnum::into_float_value).unwrap_left()
}

/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
///   or [`None`] if starting from the first dimension and ending at the last dimension
///   respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, dims: &Dims, (begin, end): (Option>, Option>), ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, Dims: ArrayLikeIndexer<'ctx>, { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_size", 64 => "__nac3_ndarray_calc_size64", bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), }; let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], false, ); let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); let end = end.unwrap_or_else(|| dims.size(ctx, generator)); ctx.builder .build_call( ndarray_calc_size_fn, &[ dims.base_ptr(ctx, generator).into(), dims.size(ctx, generator).into(), begin.into(), end.into(), ], "", ) .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(Either::unwrap_left) .unwrap() } /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] /// containing `i32` indices of the flattened index. /// /// * `index` - The index to compute the multidimensional index for. /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an /// `NDArray`. 
///
/// The returned indices live in an array allocated with `alloca`, holding one `i32` per
/// dimension of `ndarray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
    generator: &G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    index: IntValue<'ctx>,
    ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    // Select the IRRT variant matching the width of the size type.
    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_nd_indices",
        64 => "__nac3_ndarray_calc_nd_indices64",
        bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
    };
    let callee = ctx.module.get_function(fn_name).unwrap_or_else(|| {
        let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
        let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
        let signature = ctx.ctx.void_type().fn_type(
            &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
            false,
        );

        ctx.module.add_function(fn_name, signature, None)
    });

    let num_dims = ndarray.load_ndims(ctx);
    let dim_sizes = ndarray.dim_sizes();

    // Output buffer written to by the IRRT function.
    let out_indices = ctx.builder.build_array_alloca(llvm_i32, num_dims, "").unwrap();
    let dim_sizes_ptr = dim_sizes.base_ptr(ctx, generator);

    ctx.builder
        .build_call(
            callee,
            &[index.into(), dim_sizes_ptr.into(), num_dims.into(), out_indices.into()],
            "",
        )
        .unwrap();

    let out_slice = ArraySliceValue::from_ptr_val(out_indices, num_dims, None);
    TypedArrayLikeAdapter::from(
        out_slice,
        Box::new(|_, v| v.into_int_value()),
        Box::new(|_, v| v.into()),
    )
}

/// Shared implementation of [`call_ndarray_flatten_index`], which only requires shared access
/// to the generator and context.
///
/// In debug builds, asserts that `indices` has `i32` elements and a size of the same bit width
/// as the size type.
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
    generator: &G,
    ctx: &CodeGenContext<'ctx, '_>,
    ndarray: NDArrayValue<'ctx>,
    indices: &Indices,
) -> IntValue<'ctx>
where
    G: CodeGenerator + ?Sized,
    Indices: ArrayLikeIndexer<'ctx>,
{
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    debug_assert_eq!(
        IntType::try_from(indices.element_type(ctx, generator))
            .map(IntType::get_bit_width)
            .unwrap_or_default(),
        llvm_i32.get_bit_width(),
        "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
    );
    debug_assert_eq!(
        indices.size(ctx, generator).get_type().get_bit_width(),
        llvm_usize.get_bit_width(),
        "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
    );

    // Select the IRRT variant matching the width of the size type.
    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_flatten_index",
        64 => "__nac3_ndarray_flatten_index64",
        bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
    };
    let callee = ctx.module.get_function(fn_name).unwrap_or_else(|| {
        let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
        let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
        let signature = llvm_usize.fn_type(
            &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
            false,
        );

        ctx.module.add_function(fn_name, signature, None)
    });

    let num_dims = ndarray.load_ndims(ctx);
    let dim_sizes = ndarray.dim_sizes();

    let dim_sizes_ptr = dim_sizes.base_ptr(ctx, generator);
    let indices_ptr = indices.base_ptr(ctx, generator);
    let indices_len = indices.size(ctx, generator);

    let call_site = ctx
        .builder
        .build_call(
            callee,
            &[dim_sizes_ptr.into(), num_dims.into(), indices_ptr.into(), indices_len.into()],
            "",
        )
        .unwrap();

    call_site.try_as_basic_value().map_left(BasicValueEnum::into_int_value).unwrap_left()
}

/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
///   `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, indices: &Index, ) -> IntValue<'ctx> where G: CodeGenerator + ?Sized, Index: ArrayLikeIndexer<'ctx>, { call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) } /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of /// dimension and size of each dimension of the resultant `ndarray`. pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, lhs: NDArrayValue<'ctx>, rhs: NDArrayValue<'ctx>, ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_broadcast", 64 => "__nac3_ndarray_calc_broadcast64", bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), }; let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let fn_type = llvm_usize.fn_type( &[ llvm_pusize.into(), llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pusize.into(), ], false, ); ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) }); let lhs_ndims = lhs.load_ndims(ctx); let rhs_ndims = rhs.load_ndims(ctx); let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); gen_for_callback_incrementing( generator, ctx, None, llvm_usize.const_zero(), (min_ndims, false), |generator, ctx, _, idx| { let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let (lhs_dim_sz, rhs_dim_sz) = unsafe { ( lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), ) }; let llvm_usize_const_one = llvm_usize.const_int(1, false); let lhs_eqz = ctx .builder 
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") .unwrap(); let rhs_eqz = ctx .builder .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") .unwrap(); let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); let lhs_eq_rhs = ctx .builder .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") .unwrap(); let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); ctx.make_assert( generator, is_compatible, "0:ValueError", "operands could not be broadcast together", [None, None, None], ctx.current_loc, ); Ok(()) }, llvm_usize.const_int(1, false), ) .unwrap(); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); let lhs_ndims = lhs.load_ndims(ctx); let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); let rhs_ndims = rhs.load_ndims(ctx); let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); ctx.builder .build_call( ndarray_calc_broadcast_fn, &[ lhs_dims.into(), lhs_ndims.into(), rhs_dims.into(), rhs_ndims.into(), out_dims.base_ptr(ctx, generator).into(), ], "", ) .unwrap(); TypedArrayLikeAdapter::from( out_dims, Box::new(|_, v| v.into_int_value()), Box::new(|_, v| v.into()), ) } /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] /// containing the indices used for accessing `array` corresponding to the index of the broadcasted /// array `broadcast_idx`. 
///
/// The returned indices live in an array allocated with `alloca`, holding one `i32` per element
/// of `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
    'ctx,
    G: CodeGenerator + ?Sized,
    BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    array: NDArrayValue<'ctx>,
    broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    // Select the IRRT variant matching the width of the size type.
    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_broadcast_idx",
        64 => "__nac3_ndarray_calc_broadcast_idx64",
        bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
    };
    let callee = ctx.module.get_function(fn_name).unwrap_or_else(|| {
        let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
        let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
        let signature = llvm_usize.fn_type(
            &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
            false,
        );

        ctx.module.add_function(fn_name, signature, None)
    });

    let num_indices = broadcast_idx.size(ctx, generator);

    // Output buffer written to by the IRRT function.
    let out_indices = ctx.builder.build_array_alloca(llvm_i32, num_indices, "").unwrap();

    let dim_sizes_ptr = array.dim_sizes().base_ptr(ctx, generator);
    let num_dims = array.load_ndims(ctx);

    let first_index = llvm_usize.const_zero();
    // Pointer to the first element of `broadcast_idx`.
    let in_indices_ptr =
        unsafe { broadcast_idx.ptr_offset_unchecked(ctx, generator, &first_index, None) };

    ctx.builder
        .build_call(
            callee,
            &[dim_sizes_ptr.into(), num_dims.into(), in_indices_ptr.into(), out_indices.into()],
            "",
        )
        .unwrap();

    let out_slice = ArraySliceValue::from_ptr_val(out_indices, num_indices, None);
    TypedArrayLikeAdapter::from(
        out_slice,
        Box::new(|_, v| v.into_int_value()),
        Box::new(|_, v| v.into()),
    )
}