From 64ec66d3dd52c46ec965bda50bf37aadc7a936c7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 11 Nov 2024 16:16:23 +0800 Subject: [PATCH] [core] irrt: Break IRRT into several impl files Each IRRT file is now mapped to one Rust file. --- nac3core/src/codegen/irrt/list.rs | 162 ++++++ nac3core/src/codegen/irrt/math.rs | 152 ++++++ nac3core/src/codegen/irrt/mod.rs | 749 +-------------------------- nac3core/src/codegen/irrt/ndarray.rs | 384 ++++++++++++++ nac3core/src/codegen/irrt/slice.rs | 76 +++ 5 files changed, 786 insertions(+), 737 deletions(-) create mode 100644 nac3core/src/codegen/irrt/list.rs create mode 100644 nac3core/src/codegen/irrt/math.rs create mode 100644 nac3core/src/codegen/irrt/ndarray.rs create mode 100644 nac3core/src/codegen/irrt/slice.rs diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs new file mode 100644 index 00000000..a7fec59d --- /dev/null +++ b/nac3core/src/codegen/irrt/list.rs @@ -0,0 +1,162 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, CallSiteValue, IntValue}, + AddressSpace, IntPredicate, +}; +use itertools::Either; + +use super::calculate_len_for_slice_range; +use crate::codegen::{ + macros::codegen_unreachable, + values::{ArrayLikeValue, ListValue}, + CodeGenContext, CodeGenerator, +}; + +/// This function handles 'end' **inclusively**. +/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). +/// Negative index should be handled before entering this function +pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: BasicTypeEnum<'ctx>, + dest_arr: ListValue<'ctx>, + dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), + src_arr: ListValue<'ctx>, + src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), +) { + let size_ty = generator.get_size_type(ctx.ctx); + let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let int32 = ctx.ctx.i32_type(); + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let slice_assign_fun = { + let ty_vec = vec![ + int32.into(), // dest start idx + int32.into(), // dest end idx + int32.into(), // dest step + elem_ptr_type.into(), // dest arr ptr + int32.into(), // dest arr len + int32.into(), // src start idx + int32.into(), // src end idx + int32.into(), // src step + elem_ptr_type.into(), // src arr ptr + int32.into(), // src arr len + int32.into(), // size + ]; + ctx.module.get_function(fun_symbol).unwrap_or_else(|| { + let fn_t = int32.fn_type(ty_vec.as_slice(), false); + ctx.module.add_function(fun_symbol, fn_t, None) + }) + }; + + let zero = int32.const_zero(); + let one = int32.const_int(1, false); + let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); + let dest_arr_ptr = + ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); + let dest_len = dest_arr.load_size(ctx, Some("dest.len")); + let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); + let src_arr_ptr = + ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); + let src_len = src_arr.load_size(ctx, Some("src.len")); + let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + + // index in bound and positive should be done + // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and + // throw exception if not satisfied + let src_end = ctx + .builder + .build_select( + ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), + ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), + ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), + "final_e", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + let dest_end = ctx + .builder + .build_select( + ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), + ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), + ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), + "final_e", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + let src_slice_len = + calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); + let dest_slice_len = + calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); + let src_eq_dest = ctx + .builder + .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") + .unwrap(); + let src_slt_dest = ctx + .builder + .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") + .unwrap(); + let dest_step_eq_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + dest_idx.2, + dest_idx.2.get_type().const_int(1, false), + "slice_dest_step_eq_one", + ) + .unwrap(); + let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); + let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); + ctx.make_assert( + generator, + cond, + "0:ValueError", + "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", + [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], + ctx.current_loc, + ); + + let new_len = { + let args = vec![ + dest_idx.0.into(), // dest start idx + dest_idx.1.into(), // dest end idx + dest_idx.2.into(), // dest step + dest_arr_ptr.into(), // dest arr ptr + dest_len.into(), // dest arr len + src_idx.0.into(), // src start idx + src_idx.1.into(), // src end idx + src_idx.2.into(), // src step + src_arr_ptr.into(), // src arr ptr + src_len.into(), // src arr len + { + let s = match ty { + BasicTypeEnum::FloatType(t) => t.size_of(), + BasicTypeEnum::IntType(t) => t.size_of(), + BasicTypeEnum::PointerType(t) => t.size_of(), + BasicTypeEnum::StructType(t) => t.size_of().unwrap(), + _ => codegen_unreachable!(ctx), + }; + ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + } + .into(), + ]; + ctx.builder + .build_call(slice_assign_fun, args.as_slice(), "slice_assign") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() + }; + // update length + let need_update = + ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let update_bb = ctx.ctx.append_basic_block(current, "update"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); + ctx.builder.position_at_end(update_bb); + let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); + dest_arr.store_size(ctx, generator, new_len); + ctx.builder.build_unconditional_branch(cont_bb).unwrap(); + ctx.builder.position_at_end(cont_bb); +} diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs new file mode 100644 index 00000000..4bc95913 --- /dev/null +++ b/nac3core/src/codegen/irrt/math.rs @@ -0,0 +1,152 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + IntPredicate, +}; +use itertools::Either; + +use crate::codegen::{ + macros::codegen_unreachable, + {CodeGenContext, CodeGenerator}, +}; + +// repeated squaring method adapted from GNU Scientific Library: +// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + base: IntValue<'ctx>, + exp: IntValue<'ctx>, + signed: bool, +) -> IntValue<'ctx> { + let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { + (32, 32, true) => "__nac3_int_exp_int32_t", + (64, 64, true) => "__nac3_int_exp_int64_t", + (32, 32, false) => "__nac3_int_exp_uint32_t", + (64, 64, false) => "__nac3_int_exp_uint64_t", + _ => codegen_unreachable!(ctx), + }; + let base_type = base.get_type(); + let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { + let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); + ctx.module.add_function(symbol, fn_type, None) + }); + // throw exception when exp < 0 + let ge_zero = ctx + .builder + .build_int_compare( + IntPredicate::SGE, + exp, + exp.get_type().const_zero(), + "assert_int_pow_ge_0", + ) + .unwrap(); + ctx.make_assert( + generator, + ge_zero, + "0:ValueError", + "integer power must be positive or zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder + .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. +pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + v: FloatValue<'ctx>, +) -> IntValue<'ctx> { + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { + let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + ctx.module.add_function("__nac3_isinf", fn_type, None) + }); + + let ret = ctx + .builder + .build_call(intrinsic_fn, &[v.into()], "isinf") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + generator.bool_to_i1(ctx, ret) +} + +/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. +pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + v: FloatValue<'ctx>, +) -> IntValue<'ctx> { + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { + let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + ctx.module.add_function("__nac3_isnan", fn_type, None) + }); + + let ret = ctx + .builder + .build_call(intrinsic_fn, &[v.into()], "isnan") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + generator.bool_to_i1(ctx, ret) +} + +/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. +pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gamma", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gamma") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. +pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gammaln", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gammaln") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `j0` in IR. Returns an `f64` representing the result. +pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_j0", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "j0") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7e70a369..f6c4a1eb 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -3,25 +3,23 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, - values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, - AddressSpace, IntPredicate, + values::{BasicValue, BasicValueEnum, IntValue}, + IntPredicate, }; -use itertools::Either; use nac3parser::ast::Expr; -use super::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - values::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, - }, - CodeGenContext, CodeGenerator, -}; +use super::{CodeGenContext, CodeGenerator}; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; +pub use list::*; +pub use math::*; +pub use ndarray::*; +pub use slice::*; + +mod list; +mod math; +mod ndarray; +mod slice; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { @@ -62,88 +60,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) irrt_mod } -// repeated squaring method adapted from GNU Scientific Library: -// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c -pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - base: IntValue<'ctx>, - exp: IntValue<'ctx>, - signed: bool, -) -> IntValue<'ctx> { - let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { - (32, 32, true) => "__nac3_int_exp_int32_t", - (64, 64, true) => "__nac3_int_exp_int64_t", - (32, 32, false) => "__nac3_int_exp_uint32_t", - (64, 64, false) => "__nac3_int_exp_uint64_t", - _ => codegen_unreachable!(ctx), - }; - let base_type = base.get_type(); - let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { - let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); - ctx.module.add_function(symbol, fn_type, None) - }); - // throw exception when exp < 0 - let ge_zero = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - exp, - exp.get_type().const_zero(), - "assert_int_pow_ge_0", - ) - .unwrap(); - ctx.make_assert( - generator, - ge_zero, - "0:ValueError", - "integer power must be positive or zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder - .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - start: IntValue<'ctx>, - end: IntValue<'ctx>, - step: IntValue<'ctx>, -) -> IntValue<'ctx> { - const SYMBOL: &str = "__nac3_range_slice_len"; - let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - - // assert step != 0, throw exception if not - let not_zero = ctx - .builder - .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") - .unwrap(); - ctx.make_assert( - generator, - not_zero, - "0:ValueError", - "step must not be zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder - .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - /// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// NO numeric slice in python. @@ -309,644 +225,3 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( } })) } - -/// this function allows index out of range, since python -/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). -pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( - i: &Expr>, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - length: IntValue<'ctx>, -) -> Result>, String> { - const SYMBOL: &str = "__nac3_slice_index_bound"; - let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - - let i = if let Some(v) = generator.gen_expr(ctx, i)? { - v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? - } else { - return Ok(None); - }; - Ok(Some( - ctx.builder - .build_call(func, &[i.into(), length.into()], "bounded_ind") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(), - )) -} - -/// This function handles 'end' **inclusively**. -/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). -/// Negative index should be handled before entering this function -pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: BasicTypeEnum<'ctx>, - dest_arr: ListValue<'ctx>, - dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), - src_arr: ListValue<'ctx>, - src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), -) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); - let slice_assign_fun = { - let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step - elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step - elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size - ]; - ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); - ctx.module.add_function(fun_symbol, fn_t, None) - }) - }; - - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); - let dest_arr_ptr = - ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); - let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); - let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); - let src_arr_ptr = - ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); - let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); - - // index in bound and positive should be done - // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and - // throw exception if not satisfied - let src_end = ctx - .builder - .build_select( - ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), - ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), - ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - let dest_end = ctx - .builder - .build_select( - ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), - ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), - ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - let src_slice_len = - calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); - let dest_slice_len = - calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); - let src_eq_dest = ctx - .builder - .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") - .unwrap(); - let src_slt_dest = ctx - .builder - .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") - .unwrap(); - let dest_step_eq_one = ctx - .builder - .build_int_compare( - IntPredicate::EQ, - dest_idx.2, - dest_idx.2.get_type().const_int(1, false), - "slice_dest_step_eq_one", - ) - .unwrap(); - let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); - let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); - ctx.make_assert( - generator, - cond, - "0:ValueError", - "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", - [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], - ctx.current_loc, - ); - - let new_len = { - let args = vec![ - dest_idx.0.into(), // dest start idx - dest_idx.1.into(), // dest end idx - dest_idx.2.into(), // dest step - dest_arr_ptr.into(), // dest arr ptr - dest_len.into(), // dest arr len - src_idx.0.into(), // src start idx - src_idx.1.into(), // src end idx - src_idx.2.into(), // src step - src_arr_ptr.into(), // src arr ptr - src_len.into(), // src arr len - { - let s = match ty { - BasicTypeEnum::FloatType(t) => t.size_of(), - BasicTypeEnum::IntType(t) => t.size_of(), - BasicTypeEnum::PointerType(t) => t.size_of(), - BasicTypeEnum::StructType(t) => t.size_of().unwrap(), - _ => codegen_unreachable!(ctx), - }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() - } - .into(), - ]; - ctx.builder - .build_call(slice_assign_fun, args.as_slice(), "slice_assign") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() - }; - // update length - let need_update = - ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let update_bb = ctx.ctx.append_basic_block(current, "update"); - let cont_bb = ctx.ctx.append_basic_block(current, "cont"); - ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); - ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); - dest_arr.store_size(ctx, generator, new_len); - ctx.builder.build_unconditional_branch(cont_bb).unwrap(); - ctx.builder.position_at_end(cont_bb); -} - -/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. -pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> IntValue<'ctx> { - let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); - ctx.module.add_function("__nac3_isinf", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isinf") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) -} - -/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. -pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> IntValue<'ctx> { - let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); - ctx.module.add_function("__nac3_isnan", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isnan") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) -} - -/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. -pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gamma", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gamma") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. -pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gammaln", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gammaln") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `j0` in IR. Returns an `f64` representing the result. -pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_j0", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "j0") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.dim_sizes().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} diff --git a/nac3core/src/codegen/irrt/ndarray.rs b/nac3core/src/codegen/irrt/ndarray.rs new file mode 100644 index 00000000..bfec1d56 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray.rs @@ -0,0 +1,384 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, CallSiteValue, IntValue}, + AddressSpace, IntPredicate, +}; +use itertools::Either; + +use crate::codegen::{ + llvm_intrinsics, + macros::codegen_unreachable, + stmt::gen_for_callback_incrementing, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the +/// calculated total size. +/// +/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. +/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, +/// or [`None`] if starting from the first dimension and ending at the last dimension +/// respectively. +pub fn call_ndarray_calc_size<'ctx, G, Dims>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + dims: &Dims, + (begin, end): (Option>, Option>), +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Dims: ArrayLikeIndexer<'ctx>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_size", + 64 => "__nac3_ndarray_calc_size64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_size_fn_t = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], + false, + ); + let ndarray_calc_size_fn = + ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { + ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + }); + + let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); + let end = end.unwrap_or_else(|| dims.size(ctx, generator)); + ctx.builder + .build_call( + ndarray_calc_size_fn, + &[ + dims.base_ptr(ctx, generator).into(), + dims.size(ctx, generator).into(), + begin.into(), + end.into(), + ], + "", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// containing `i32` indices of the flattened index. +/// +/// * `index` - The index to compute the multidimensional index for. +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_nd_indices", + 64 => "__nac3_ndarray_calc_nd_indices64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_nd_indices_fn = + ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], + false, + ); + + ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.dim_sizes(); + + let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); + + ctx.builder + .build_call( + ndarray_calc_nd_indices_fn, + &[ + index.into(), + ndarray_dims.base_ptr(ctx, generator).into(), + ndarray_num_dims.into(), + indices.into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: &Indices, +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Indices: ArrayLikeIndexer<'ctx>, +{ + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + debug_assert_eq!( + IntType::try_from(indices.element_type(ctx, generator)) + .map(IntType::get_bit_width) + .unwrap_or_default(), + llvm_i32.get_bit_width(), + "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" + ); + debug_assert_eq!( + indices.size(ctx, generator).get_type().get_bit_width(), + llvm_usize.get_bit_width(), + "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" + ); + + let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_flatten_index", + 64 => "__nac3_ndarray_flatten_index64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_flatten_index_fn = + ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], + false, + ); + + ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.dim_sizes(); + + let index = ctx + .builder + .build_call( + ndarray_flatten_index_fn, + &[ + ndarray_dims.base_ptr(ctx, generator).into(), + ndarray_num_dims.into(), + indices.base_ptr(ctx, generator).into(), + indices.size(ctx, generator).into(), + ], + "", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + index +} + +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the +/// multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx, G, Index>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: &Index, +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Index: ArrayLikeIndexer<'ctx>, +{ + call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. +pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast", + 64 => "__nac3_ndarray_calc_broadcast64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_ndims = rhs.load_ndims(ctx); + let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (min_ndims, false), + |generator, ctx, _, idx| { + let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); + let (lhs_dim_sz, rhs_dim_sz) = unsafe { + ( + lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + ) + }; + + let llvm_usize_const_one = llvm_usize.const_int(1, false); + let lhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let rhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); + + let lhs_eq_rhs = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") + .unwrap(); + + let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); + + ctx.make_assert( + generator, + is_compatible, + "0:ValueError", + "operands could not be broadcast together", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); + let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_ndims = rhs.load_ndims(ctx); + let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); + let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + lhs_dims.into(), + lhs_ndims.into(), + rhs_dims.into(), + rhs_ndims.into(), + out_dims.base_ptr(ctx, generator).into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + out_dims, + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] +/// containing the indices used for accessing `array` corresponding to the index of the broadcasted +/// array `broadcast_idx`. +pub fn call_ndarray_calc_broadcast_index< + 'ctx, + G: CodeGenerator + ?Sized, + BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, +>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + array: NDArrayValue<'ctx>, + broadcast_idx: &BroadcastIdx, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast_idx", + 64 => "__nac3_ndarray_calc_broadcast_idx64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let broadcast_size = broadcast_idx.size(ctx, generator); + let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); + + let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_ndims = array.load_ndims(ctx); + let broadcast_idx_ptr = unsafe { + broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} diff --git a/nac3core/src/codegen/irrt/slice.rs b/nac3core/src/codegen/irrt/slice.rs new file mode 100644 index 00000000..eb7037ac --- /dev/null +++ b/nac3core/src/codegen/irrt/slice.rs @@ -0,0 +1,76 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, IntValue}, + IntPredicate, +}; +use itertools::Either; +use nac3parser::ast::Expr; + +use crate::{ + codegen::{CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, +}; + +/// this function allows index out of range, since python +/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). +pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( + i: &Expr>, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + length: IntValue<'ctx>, +) -> Result>, String> { + const SYMBOL: &str = "__nac3_slice_index_bound"; + let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { + let i32_t = ctx.ctx.i32_type(); + let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); + ctx.module.add_function(SYMBOL, fn_t, None) + }); + + let i = if let Some(v) = generator.gen_expr(ctx, i)? { + v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? + } else { + return Ok(None); + }; + Ok(Some( + ctx.builder + .build_call(func, &[i.into(), length.into()], "bounded_ind") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(), + )) +} + +pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + start: IntValue<'ctx>, + end: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + const SYMBOL: &str = "__nac3_range_slice_len"; + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { + let i32_t = ctx.ctx.i32_type(); + let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + ctx.module.add_function(SYMBOL, fn_t, None) + }); + + // assert step != 0, throw exception if not + let not_zero = ctx + .builder + .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") + .unwrap(); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "step must not be zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder + .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +}