[core] Minor improvements to IRRT and add missing documentation

This commit is contained in:
David Mak 2024-12-17 16:10:00 +08:00
parent a2efccee1c
commit 1b1cb775af
13 changed files with 242 additions and 112 deletions

View File

@ -17,6 +17,7 @@ pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;
/// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
/// Generate function call and returns the function return value.

View File

@ -24,42 +24,52 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let llvm_i32 = ctx.ctx.i32_type();
assert_eq!(dest_idx.0.get_type(), llvm_i32);
assert_eq!(dest_idx.1.get_type(), llvm_i32);
assert_eq!(dest_idx.2.get_type(), llvm_i32);
assert_eq!(src_idx.0.get_type(), llvm_i32);
assert_eq!(src_idx.1.get_type(), llvm_i32);
assert_eq!(src_idx.2.get_type(), llvm_i32);
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
llvm_i32.into(), // dest start idx
llvm_i32.into(), // dest end idx
llvm_i32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
llvm_i32.into(), // dest arr len
llvm_i32.into(), // src start idx
llvm_i32.into(), // src end idx
llvm_i32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
llvm_i32.into(), // src arr len
llvm_i32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let zero = llvm_i32.const_zero();
let one = llvm_i32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let dest_len =
ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
let src_len =
ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
@ -136,7 +146,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap()
}
.into(),
];
@ -147,6 +157,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
@ -155,7 +166,8 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
let new_len =
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);

View File

@ -62,8 +62,13 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_f64 = ctx.ctx.f64_type();
assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
@ -84,8 +89,13 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_f64 = ctx.ctx.f64_type();
assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
@ -104,6 +114,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
@ -121,6 +133,8 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) ->
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
@ -138,6 +152,8 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
assert_eq!(v.get_type(), llvm_f64);
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)

View File

@ -130,10 +130,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
let llvm_i32 = ctx.ctx.i32_type();
let zero = llvm_i32.const_zero();
let one = llvm_i32.const_int(1, false);
let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap();
Ok(Some(match (start, end, step) {
(s, e, None) => (
if let Some(s) = s.as_ref() {
@ -142,7 +143,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
None => return Ok(None),
}
} else {
int32.const_zero()
llvm_i32.const_zero()
},
{
let e = if let Some(s) = e.as_ref() {

View File

@ -1,4 +1,5 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
@ -7,19 +8,26 @@ use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function},
irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{ndarray::NDArrayValue, ProxyValue},
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`.
///
/// Assets that `shape` does not contain negative dimensions.
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndims: IntValue<'ctx>,
shape: PointerValue<'ctx>,
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!(
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(
generator,
ctx,
@ -30,23 +38,37 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
&[
(llvm_usize.into(), shape.size(ctx, generator).into()),
(llvm_pusize.into(), shape.base_ptr(ctx, generator).into()),
],
None,
None,
);
}
/// Generates a call to `__nac3_ndarray_util_assert_shape_output_shape_same`.
///
/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to
/// an `ndarray`.
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray_ndims: IntValue<'ctx>,
ndarray_shape: PointerValue<'ctx>,
output_ndims: IntValue<'ctx>,
output_shape: IntValue<'ctx>,
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!(
BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
assert_eq!(
BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(
generator,
ctx,
@ -58,16 +80,20 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
&name,
Some(llvm_usize.into()),
&[
(llvm_usize.into(), ndarray_ndims.into()),
(llvm_pusize.into(), ndarray_shape.into()),
(llvm_usize.into(), output_ndims.into()),
(llvm_pusize.into(), output_shape.into()),
(llvm_usize.into(), ndarray_shape.size(ctx, generator).into()),
(llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()),
(llvm_usize.into(), output_shape.size(ctx, generator).into()),
(llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()),
],
None,
None,
);
}
/// Generates a call to `__nac3_ndarray_size`.
///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
/// `ndarray`, corresponding to the value of `ndarray.size`.
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -90,6 +116,10 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_nbytes`.
///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -112,6 +142,10 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_len`.
///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
/// the `ndarray`, corresponding to the value of `ndarray.__len__`.
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -134,6 +168,9 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_is_c_contiguous`.
///
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -156,6 +193,9 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_get_nth_pelement`.
///
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -167,6 +207,8 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type();
assert_eq!(index.get_type(), llvm_usize);
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function(
@ -181,11 +223,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`.
///
/// `indices` must have the same number of elements as the number of dimensions in `ndarray`.
///
/// Returns a [`PointerValue`] to the element indexed by `indices`.
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>,
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
@ -193,6 +240,11 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = ndarray.get_type().as_base_type();
assert_eq!(
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
@ -202,7 +254,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
Some(llvm_pi8.into()),
&[
(llvm_ndarray.into(), ndarray.as_base_value().into()),
(llvm_pusize.into(), indices.into()),
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
],
Some("pelement"),
None,
@ -211,6 +263,9 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
.unwrap()
}
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
///
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -231,6 +286,11 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
);
}
/// Generates a call to `__nac3_ndarray_copy_data`.
///
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in
/// `dst_ndarray`.
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,

View File

@ -5,6 +5,11 @@ use crate::codegen::{
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_index`.
///
/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the
/// operation `dst_ndarray = src_ndarray[indices]`.
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,

View File

@ -1,4 +1,5 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, IntValue},
AddressSpace,
};
@ -9,21 +10,29 @@ use crate::codegen::{
types::ProxyType,
values::{
ndarray::{NDArrayValue, NDIterValue},
ArrayLikeValue, ArraySliceValue, ProxyValue,
ProxyValue, TypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_nditer_initialize`.
///
/// Initializes the `iter` object.
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!(
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
create_and_call_function(
@ -40,6 +49,10 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
);
}
/// Generates a call to `__nac3_nditer_initialize_has_element`.
///
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
/// object.
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -59,6 +72,9 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
}
/// Generates a call to `__nac3_nditer_next`.
///
/// Moves `iter` to point to the next element.
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,

View File

@ -1,10 +1,11 @@
use inkwell::{
types::IntType,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use super::get_usize_dependent_function_name;
use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
@ -23,8 +24,8 @@ mod basic;
mod indexing;
mod iter;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
/// Generates a call to `__nac3_ndarray_calc_size`. Returns a
/// [`usize`][CodeGenerator::get_size_type] representing the calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
@ -43,18 +44,22 @@ where
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
assert!(begin.is_none_or(|begin| begin.get_type() == llvm_usize));
assert!(end.is_none_or(|end| end.get_type() == llvm_usize));
assert_eq!(
BasicTypeEnum::try_from(dims.element_type(ctx, generator)).unwrap(),
llvm_usize.into()
);
let ndarray_calc_size_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_size");
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
ctx.module.get_function(&ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(&ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
@ -76,10 +81,10 @@ where
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypedArrayLikeAdapter`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `index` - The `llvm_usize` index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
@ -94,19 +99,18 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
assert_eq!(index.get_type(), llvm_usize);
let ndarray_calc_nd_indices_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_nd_indices");
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
ctx.module.get_function(&ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
ctx.module.add_function(&ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
@ -134,15 +138,21 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for
/// the multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
Index: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -163,19 +173,16 @@ where
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index");
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
@ -201,27 +208,8 @@ where
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`]
/// containing the size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -231,13 +219,10 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn_name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_calc_broadcast");
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
ctx.module.get_function(&ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
@ -249,7 +234,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
ctx.module.add_function(&ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);

View File

@ -6,6 +6,13 @@ use itertools::Either;
use crate::codegen::{CodeGenContext, CodeGenerator};
/// Invokes the `__nac3_range_slice_len` in IRRT.
///
/// - `start`: The `i32` start value for the slice.
/// - `end`: The `i32` end value for the slice.
/// - `step`: The `i32` step value for the slice.
///
/// Returns an `i32` value of the length of the slice.
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -14,9 +21,15 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let llvm_i32 = ctx.ctx.i32_type();
assert_eq!(start.get_type(), llvm_i32);
assert_eq!(end.get_type(), llvm_i32);
assert_eq!(step.get_type(), llvm_i32);
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
@ -33,6 +46,7 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)

View File

@ -14,7 +14,7 @@ use crate::codegen::{
types::structure::{check_struct_type_matches_fields, StructField, StructFields},
values::{
ndarray::{NDArrayValue, NDIterValue},
ArraySliceValue, ProxyValue,
ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter,
},
CodeGenContext, CodeGenerator,
};
@ -128,6 +128,11 @@ impl<'ctx> NDIterType<'ctx> {
}
/// Allocate an [`NDIter`] that iterates through the given `ndarray`.
///
/// Note: This function allocates an array on the stack at the current builder location, which
/// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to
/// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the
/// [`NDIter`] is no longer needed.
#[must_use]
pub fn construct<G: CodeGenerator + ?Sized>(
&self,
@ -141,16 +146,12 @@ impl<'ctx> NDIterType<'ctx> {
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices =
generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap();
let indices =
TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into());
let nditer = <Self as ProxyType<'ctx>>::Value::from_pointer_value(
nditer,
ndarray,
indices,
self.llvm_usize,
None,
);
let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None);
irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices);
irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices);
nditer
}

View File

@ -265,6 +265,14 @@ where
) -> IntValue<'ctx> {
self.adapted.size(ctx, generator)
}
fn as_slice_value<CG: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &CG,
) -> ArraySliceValue<'ctx> {
self.adapted.as_slice_value(ctx, generator)
}
}
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>

View File

@ -358,6 +358,10 @@ impl<'ctx> NDArrayValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
}
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and
/// copy the contents over.
///
/// The new ndarray will own its data and will be C-contiguous.
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,

View File

@ -69,7 +69,10 @@ impl<'ctx> NDIterValue<'ctx> {
irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self);
}
fn element_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
fn element_field(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).element
}
@ -138,6 +141,10 @@ impl<'ctx> NDArrayValue<'ctx> {
///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.)
///
/// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and
/// after invoking this function respectively. See [`NDIterType::construct`] for an explanation
/// on why this is suggested.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,