forked from M-Labs/nac3
core: Add const variants to NDArray element getters
This commit is contained in:
parent
1eacaf9afa
commit
976a9512c1
|
@ -1,12 +1,12 @@
|
|||
use inkwell::{
|
||||
IntPredicate,
|
||||
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
|
||||
};
|
||||
use crate::codegen::{
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
||||
stmt::gen_for_callback,
|
||||
};
|
||||
|
||||
|
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||
};
|
||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||
debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||
|
||||
let index = call_ndarray_flatten_index(
|
||||
generator,
|
||||
|
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
pub unsafe fn ptr_offset_unchecked_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let index = call_ndarray_flatten_index_const(
|
||||
generator,
|
||||
ctx,
|
||||
self.0,
|
||||
indices,
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(ctx),
|
||||
&[index],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the index specified by `indices`.
|
||||
pub fn ptr_offset_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let indices_elem_ty = indices.get_type().get_element_type();
|
||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected [int32] but got [{indices_elem_ty}]")
|
||||
};
|
||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
|
||||
|
||||
let nidx_leq_ndims = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLE,
|
||||
llvm_usize.const_int(indices.get_type().len() as u64, false),
|
||||
self.0.load_ndims(ctx),
|
||||
""
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
nidx_leq_ndims,
|
||||
"0:IndexError",
|
||||
"invalid index to scalar variable",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
for idx in 0..indices.get_type().len() {
|
||||
let i = llvm_usize.const_int(idx as u64, false);
|
||||
|
||||
let dim_idx = ctx.builder
|
||||
.build_extract_value(indices, idx, "")
|
||||
.map(|v| v.into_int_value())
|
||||
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
|
||||
.unwrap();
|
||||
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
|
||||
|
||||
let dim_lt = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim_idx,
|
||||
dim_sz,
|
||||
""
|
||||
);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
dim_lt,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(dim_idx), Some(dim_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the index specified by `indices`.
|
||||
pub fn ptr_offset(
|
||||
&self,
|
||||
|
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
pub unsafe fn get_unsafe_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
pub unsafe fn get_unsafe(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the index specified by `indices`.
|
||||
pub fn get_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the index specified by `indices`.
|
||||
pub fn get(
|
||||
&self,
|
||||
|
|
|
@ -10,8 +10,8 @@ use inkwell::{
|
|||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::BasicTypeEnum,
|
||||
values::{FloatValue, IntValue, PointerValue},
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{ArrayValue, FloatValue, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use nac3parser::ast::Expr;
|
||||
|
@ -707,16 +707,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
Ok(indices)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx>(
|
||||
fn call_ndarray_flatten_index_impl<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ListValue<'ctx>,
|
||||
indices: PointerValue<'ctx>,
|
||||
indices_size: IntValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
@ -724,6 +720,19 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
debug_assert_eq!(
|
||||
IntType::try_from(indices.get_type().get_element_type())
|
||||
.map(|itype| itype.get_bit_width())
|
||||
.unwrap_or_default(),
|
||||
llvm_i32.get_bit_width(),
|
||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
debug_assert_eq!(
|
||||
indices_size.get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
|
@ -745,8 +754,6 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
let indices_size = indices.load_size(ctx, None);
|
||||
let indices_data = indices.get_data();
|
||||
|
||||
let index = ctx.builder
|
||||
.build_call(
|
||||
|
@ -765,4 +772,71 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||
.unwrap();
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ListValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let indices_size = indices.load_size(ctx, None);
|
||||
let indices_data = indices.get_data();
|
||||
|
||||
call_ndarray_flatten_index_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
indices_data.get_ptr(ctx),
|
||||
indices_size,
|
||||
)
|
||||
}
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ArrayValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let indices_size = indices.get_type().len();
|
||||
let indices_alloca = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
indices.get_type().get_element_type(),
|
||||
llvm_usize.const_int(indices_size as u64, false),
|
||||
None
|
||||
)?;
|
||||
for i in 0..indices_size {
|
||||
let v = ctx.builder.build_extract_value(indices, i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let elem_ptr = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
indices_alloca,
|
||||
&[ctx.ctx.i32_type().const_int(i as u64, false)],
|
||||
""
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(elem_ptr, v);
|
||||
}
|
||||
|
||||
call_ndarray_flatten_index_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
indices_alloca,
|
||||
llvm_usize.const_int(indices_size as u64, false),
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue