nac3/nac3core/src/codegen/irrt/mod.rs

953 lines
36 KiB
Rust
Raw Normal View History

use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
2022-01-09 19:55:17 +08:00
AddressSpace, IntPredicate,
};
2024-02-19 19:30:25 +08:00
use itertools::Either;
2022-01-09 19:55:17 +08:00
use nac3parser::ast::Expr;
use super::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
2023-12-08 17:43:32 +08:00
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer",
);
let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx).unwrap();
let inline_attr = Attribute::get_named_enum_kind_id("alwaysinline");
2022-01-09 19:55:17 +08:00
for symbol in &[
"__nac3_int_exp_int32_t",
"__nac3_int_exp_int64_t",
"__nac3_range_slice_len",
"__nac3_slice_index_bound",
] {
let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
}
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
2022-01-09 01:05:17 +08:00
irrt_mod
}
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
2023-12-06 11:49:02 +08:00
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
2022-03-05 03:45:09 +08:00
signed: bool,
) -> IntValue<'ctx> {
2022-03-05 03:45:09 +08:00
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
2024-08-23 13:10:55 +08:00
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
2024-06-12 14:45:03 +08:00
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
2022-01-09 19:55:17 +08:00
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
2023-12-06 11:49:02 +08:00
ctx: &mut CodeGenContext<'ctx, '_>,
2022-01-09 19:55:17 +08:00
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
2024-06-12 14:45:03 +08:00
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
2022-01-09 19:55:17 +08:00
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
2022-01-09 19:55:17 +08:00
.unwrap()
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
/// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to
/// NO numeric slice in python.
///
/// equivalent code:
/// ```pseudo_code
/// match (start, end, step):
/// case (s, e, None | Some(step)) if step > 0:
/// return (
/// match s:
/// case None:
/// 0
/// case Some(s):
/// handle_in_bound(s)
/// ,match e:
/// case None:
/// length - 1
/// case Some(e):
/// handle_in_bound(e) - 1
/// ,step == None ? 1 : step
/// )
/// case (s, e, Some(step)) if step < 0:
/// return (
/// match s:
/// case None:
/// length - 1
/// case Some(s):
/// s = handle_in_bound(s)
/// if s == length:
/// s - 1
/// else:
/// s
/// ,match e:
/// case None:
/// 0
/// case Some(e):
/// handle_in_bound(e) + 1
/// ,step
/// )
2022-01-09 19:55:17 +08:00
/// ```
2023-12-06 11:49:02 +08:00
pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
2022-01-09 19:55:17 +08:00
start: &Option<Box<Expr<Option<Type>>>>,
end: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
2023-12-06 11:49:02 +08:00
ctx: &mut CodeGenContext<'ctx, '_>,
2022-01-09 19:55:17 +08:00
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
2022-01-09 19:55:17 +08:00
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
2024-02-19 19:30:25 +08:00
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
Ok(Some(match (start, end, step) {
2022-01-09 19:55:17 +08:00
(s, e, None) => (
if let Some(s) = s.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
int32.const_zero()
},
2022-01-09 19:55:17 +08:00
{
let e = if let Some(s) = e.as_ref() {
match handle_slice_index_bound(s, ctx, generator, length)? {
Some(v) => v,
None => return Ok(None),
}
} else {
length
};
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
2022-01-09 19:55:17 +08:00
},
one,
),
(s, e, Some(step)) => {
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else {
2024-06-12 14:45:03 +08:00
return Ok(None);
};
// assert step != 0, throw exception if not
2024-06-12 14:45:03 +08:00
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"slice step cannot be zero",
[None, None, None],
ctx.current_loc,
);
2024-02-19 19:30:25 +08:00
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
2024-06-12 14:45:03 +08:00
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
2022-01-09 19:55:17 +08:00
(
match s {
Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
2024-06-12 14:45:03 +08:00
return Ok(None);
};
2022-01-09 19:55:17 +08:00
ctx.builder
.build_select(
2024-06-12 14:45:03 +08:00
ctx.builder
.build_and(
ctx.builder
.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)
.unwrap(),
neg,
"should_minus_one",
)
.unwrap(),
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
2022-01-09 19:55:17 +08:00
s,
"final_start",
)
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap()
2022-01-09 19:55:17 +08:00
}
2024-06-12 14:45:03 +08:00
None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap(),
2022-01-09 19:55:17 +08:00
},
match e {
Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
2024-06-12 14:45:03 +08:00
return Ok(None);
};
2022-01-09 19:55:17 +08:00
ctx.builder
.build_select(
neg,
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
2022-01-09 19:55:17 +08:00
"final_end",
)
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap()
2022-01-09 19:55:17 +08:00
}
2024-06-12 14:45:03 +08:00
None => ctx
.builder
.build_select(neg, zero, len_id, "end")
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap(),
2022-01-09 19:55:17 +08:00
},
step,
)
}
}))
2022-01-09 19:55:17 +08:00
}
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
2023-12-06 11:49:02 +08:00
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
2022-01-09 19:55:17 +08:00
i: &Expr<Option<Type>>,
2023-12-06 11:49:02 +08:00
ctx: &mut CodeGenContext<'ctx, '_>,
2022-01-09 19:55:17 +08:00
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
2022-01-09 19:55:17 +08:00
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
2024-06-12 14:45:03 +08:00
return Ok(None);
};
2024-06-12 14:45:03 +08:00
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
2022-01-09 19:55:17 +08:00
}
/// This function handles 'end' **inclusively**.
2023-12-08 17:43:32 +08:00
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
2022-01-09 19:55:17 +08:00
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
2023-12-06 11:49:02 +08:00
ctx: &mut CodeGenContext<'ctx, '_>,
2022-01-09 19:55:17 +08:00
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
2022-01-09 19:55:17 +08:00
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
2022-01-09 19:55:17 +08:00
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
2023-01-12 19:31:03 +08:00
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
2022-01-09 19:55:17 +08:00
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
2022-01-09 19:55:17 +08:00
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
2024-06-12 14:45:03 +08:00
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
2024-02-19 19:30:25 +08:00
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
2024-06-12 14:45:03 +08:00
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
2024-02-19 19:30:25 +08:00
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
2022-01-09 19:55:17 +08:00
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
2022-01-09 19:55:17 +08:00
// throw exception if not satisfied
2024-06-12 14:45:03 +08:00
let src_end = ctx
.builder
.build_select(
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap();
2024-06-12 14:45:03 +08:00
let dest_end = ctx
.builder
.build_select(
2024-06-12 14:45:03 +08:00
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
2024-02-19 19:30:25 +08:00
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
2024-06-12 14:45:03 +08:00
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
2024-02-19 19:30:25 +08:00
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
2024-08-23 13:10:55 +08:00
_ => codegen_unreachable!(ctx),
};
2024-02-19 19:30:25 +08:00
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
2022-01-09 19:55:17 +08:00
// update length
2024-06-12 14:45:03 +08:00
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
2022-01-09 19:55:17 +08:00
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
2024-02-19 19:30:25 +08:00
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
2022-01-09 19:55:17 +08:00
ctx.builder.position_at_end(update_bb);
2024-06-12 14:45:03 +08:00
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
2024-02-19 19:30:25 +08:00
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
2022-01-09 19:55:17 +08:00
ctx.builder.position_at_end(cont_bb);
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
2023-12-06 11:49:02 +08:00
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
2024-06-12 14:45:03 +08:00
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
2023-12-06 11:49:02 +08:00
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
2024-06-12 14:45:03 +08:00
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
2024-06-12 14:45:03 +08:00
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
2024-06-12 14:45:03 +08:00
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
2024-06-12 14:45:03 +08:00
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
2024-02-20 18:07:55 +08:00
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
2024-06-12 14:45:03 +08:00
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
2024-08-20 11:29:03 +08:00
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
2024-08-23 13:10:55 +08:00
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
2024-06-12 14:45:03 +08:00
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
2024-02-20 18:07:55 +08:00
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
2024-08-21 11:10:52 +08:00
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
2024-03-22 16:57:06 +08:00
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
2024-08-23 13:10:55 +08:00
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
2024-06-12 14:45:03 +08:00
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
2024-06-12 14:45:03 +08:00
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
2024-06-12 14:45:03 +08:00
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
2024-02-19 19:30:25 +08:00
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
2024-02-19 19:30:25 +08:00
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
2024-03-22 16:57:06 +08:00
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
2024-02-20 18:07:55 +08:00
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
2024-08-23 13:10:55 +08:00
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
2024-06-12 14:45:03 +08:00
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
2024-06-12 14:45:03 +08:00
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
2024-06-12 14:45:03 +08:00
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
2024-02-19 19:30:25 +08:00
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
2024-02-20 18:07:55 +08:00
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
2024-02-20 18:07:55 +08:00
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
2024-08-21 11:10:52 +08:00
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
2024-06-12 14:45:03 +08:00
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
2024-08-23 13:10:55 +08:00
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
2024-06-12 14:45:03 +08:00
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
2024-06-12 14:45:03 +08:00
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
2024-06-12 14:45:03 +08:00
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
2024-06-12 14:45:03 +08:00
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
2024-06-12 14:45:03 +08:00
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
2024-08-23 13:10:55 +08:00
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
2024-06-12 14:45:03 +08:00
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
2024-06-12 14:45:03 +08:00
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
2024-06-12 14:45:03 +08:00
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
2024-06-12 14:45:03 +08:00
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
2024-06-12 14:45:03 +08:00
}