forked from M-Labs/nac3
1
0
Fork 0

core: Add const variants to NDArray element getters

This commit is contained in:
David Mak 2024-02-15 15:10:12 +08:00
parent 1eacaf9afa
commit 976a9512c1
2 changed files with 198 additions and 15 deletions

View File

@ -1,12 +1,12 @@
use inkwell::{
IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
};
use crate::codegen::{
CodeGenContext,
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
stmt::gen_for_callback,
};
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
let index = call_ndarray_flatten_index(
generator,
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
}
}
pub unsafe fn ptr_offset_unchecked_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let index = call_ndarray_flatten_index_const(
generator,
ctx,
self.0,
indices,
).unwrap();
unsafe {
ctx.builder.build_in_bounds_gep(
self.get_ptr(ctx),
&[index],
name.unwrap_or_default(),
)
}
}
/// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices.get_type().get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected [int32] but got [{indices_elem_ty}]")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
let nidx_leq_ndims = ctx.builder.build_int_compare(
IntPredicate::SLE,
llvm_usize.const_int(indices.get_type().len() as u64, false),
self.0.load_ndims(ctx),
""
);
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
for idx in 0..indices.get_type().len() {
let i = llvm_usize.const_int(idx as u64, false);
let dim_idx = ctx.builder
.build_extract_value(indices, idx, "")
.map(|v| v.into_int_value())
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
.unwrap();
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,
dim_idx,
dim_sz,
""
);
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
}
unsafe {
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
}
}
/// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset(
&self,
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
}
}
pub unsafe fn get_unsafe_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
pub unsafe fn get_unsafe(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the index specified by `indices`.
pub fn get_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the index specified by `indices`.
pub fn get(
&self,

View File

@ -10,8 +10,8 @@ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::BasicTypeEnum,
values::{FloatValue, IntValue, PointerValue},
types::{BasicTypeEnum, IntType},
values::{ArrayValue, FloatValue, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use nac3parser::ast::Expr;
@ -707,16 +707,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
Ok(indices)
}
/// Generates a call to `__nac3_ndarray_flatten_index`.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>(
fn call_ndarray_flatten_index_impl<'ctx>(
generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>,
indices: PointerValue<'ctx>,
indices_size: IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -724,6 +720,19 @@ pub fn call_ndarray_flatten_index<'ctx>(
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.get_type().get_element_type())
.map(|itype| itype.get_bit_width())
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices_size.get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
@ -745,8 +754,6 @@ pub fn call_ndarray_flatten_index<'ctx>(
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.get_dims();
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data();
let index = ctx.builder
.build_call(
@ -765,4 +772,71 @@ pub fn call_ndarray_flatten_index<'ctx>(
.unwrap();
Ok(index)
}
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>(
generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data();
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_data.get_ptr(ctx),
indices_size,
)
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index_const<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ArrayValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.get_type().len();
let indices_alloca = generator.gen_array_var_alloc(
ctx,
indices.get_type().get_element_type(),
llvm_usize.const_int(indices_size as u64, false),
None
)?;
for i in 0..indices_size {
let v = ctx.builder.build_extract_value(indices, i, "")
.unwrap()
.into_int_value();
let elem_ptr = unsafe {
ctx.builder.build_in_bounds_gep(
indices_alloca,
&[ctx.ctx.i32_type().const_int(i as u64, false)],
""
)
};
ctx.builder.build_store(elem_ptr, v);
}
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_alloca,
llvm_usize.const_int(indices_size as u64, false),
)
}