1
0
forked from M-Labs/nac3

core: Add const variants to NDArray element getters

This commit is contained in:
David Mak 2024-02-15 15:10:12 +08:00
parent 1eacaf9afa
commit 976a9512c1
2 changed files with 198 additions and 15 deletions

View File

@ -1,12 +1,12 @@
use inkwell::{ use inkwell::{
IntPredicate, IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
}; };
use crate::codegen::{ use crate::codegen::{
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
stmt::gen_for_callback, stmt::gen_for_callback,
}; };
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}") panic!("Expected list[int32] but got {indices_elem_ty}")
}; };
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}"); debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
let index = call_ndarray_flatten_index( let index = call_ndarray_flatten_index(
generator, generator,
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
} }
} }
pub unsafe fn ptr_offset_unchecked_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let index = call_ndarray_flatten_index_const(
generator,
ctx,
self.0,
indices,
).unwrap();
unsafe {
ctx.builder.build_in_bounds_gep(
self.get_ptr(ctx),
&[index],
name.unwrap_or_default(),
)
}
}
/// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices.get_type().get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected [int32] but got [{indices_elem_ty}]")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
let nidx_leq_ndims = ctx.builder.build_int_compare(
IntPredicate::SLE,
llvm_usize.const_int(indices.get_type().len() as u64, false),
self.0.load_ndims(ctx),
""
);
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
for idx in 0..indices.get_type().len() {
let i = llvm_usize.const_int(idx as u64, false);
let dim_idx = ctx.builder
.build_extract_value(indices, idx, "")
.map(|v| v.into_int_value())
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
.unwrap();
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,
dim_idx,
dim_sz,
""
);
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
}
unsafe {
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
}
}
/// Returns the pointer to the data at the index specified by `indices`. /// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset( pub fn ptr_offset(
&self, &self,
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
} }
} }
pub unsafe fn get_unsafe_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
pub unsafe fn get_unsafe( pub unsafe fn get_unsafe(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.builder.build_load(ptr, name.unwrap_or_default()) ctx.builder.build_load(ptr, name.unwrap_or_default())
} }
/// Returns the data at the index specified by `indices`.
pub fn get_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the index specified by `indices`. /// Returns the data at the index specified by `indices`.
pub fn get( pub fn get(
&self, &self,

View File

@ -10,8 +10,8 @@ use inkwell::{
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::BasicTypeEnum, types::{BasicTypeEnum, IntType},
values::{FloatValue, IntValue, PointerValue}, values::{ArrayValue, FloatValue, IntValue, PointerValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
@ -707,16 +707,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
Ok(indices) Ok(indices)
} }
/// Generates a call to `__nac3_ndarray_flatten_index`. fn call_ndarray_flatten_index_impl<'ctx>(
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>, indices: PointerValue<'ctx>,
indices_size: IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String> { ) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -724,6 +720,19 @@ pub fn call_ndarray_flatten_index<'ctx>(
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.get_type().get_element_type())
.map(|itype| itype.get_bit_width())
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices_size.get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index", 32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64", 64 => "__nac3_ndarray_flatten_index64",
@ -745,8 +754,6 @@ pub fn call_ndarray_flatten_index<'ctx>(
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.get_dims(); let ndarray_dims = ndarray.get_dims();
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data();
let index = ctx.builder let index = ctx.builder
.build_call( .build_call(
@ -765,4 +772,71 @@ pub fn call_ndarray_flatten_index<'ctx>(
.unwrap(); .unwrap();
Ok(index) Ok(index)
} }
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>(
generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data();
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_data.get_ptr(ctx),
indices_size,
)
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index_const<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ArrayValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.get_type().len();
let indices_alloca = generator.gen_array_var_alloc(
ctx,
indices.get_type().get_element_type(),
llvm_usize.const_int(indices_size as u64, false),
None
)?;
for i in 0..indices_size {
let v = ctx.builder.build_extract_value(indices, i, "")
.unwrap()
.into_int_value();
let elem_ptr = unsafe {
ctx.builder.build_in_bounds_gep(
indices_alloca,
&[ctx.ctx.i32_type().const_int(i as u64, false)],
""
)
};
ctx.builder.build_store(elem_ptr, v);
}
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_alloca,
llvm_usize.const_int(indices_size as u64, false),
)
}