forked from M-Labs/nac3
core: Add const variants to NDArray element getters
This commit is contained in:
parent
1eacaf9afa
commit
976a9512c1
@ -1,12 +1,12 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
CodeGenContext,
|
CodeGenContext,
|
||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
||||||
stmt::gen_for_callback,
|
stmt::gen_for_callback,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||||
};
|
};
|
||||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||||
|
|
||||||
let index = call_ndarray_flatten_index(
|
let index = call_ndarray_flatten_index(
|
||||||
generator,
|
generator,
|
||||||
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub unsafe fn ptr_offset_unchecked_const(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
indices: ArrayValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let index = call_ndarray_flatten_index_const(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
self.0,
|
||||||
|
indices,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
self.get_ptr(ctx),
|
||||||
|
&[index],
|
||||||
|
name.unwrap_or_default(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the pointer to the data at the index specified by `indices`.
|
||||||
|
pub fn ptr_offset_const(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
indices: ArrayValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let indices_elem_ty = indices.get_type().get_element_type();
|
||||||
|
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||||
|
panic!("Expected [int32] but got [{indices_elem_ty}]")
|
||||||
|
};
|
||||||
|
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
|
||||||
|
|
||||||
|
let nidx_leq_ndims = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SLE,
|
||||||
|
llvm_usize.const_int(indices.get_type().len() as u64, false),
|
||||||
|
self.0.load_ndims(ctx),
|
||||||
|
""
|
||||||
|
);
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
nidx_leq_ndims,
|
||||||
|
"0:IndexError",
|
||||||
|
"invalid index to scalar variable",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
for idx in 0..indices.get_type().len() {
|
||||||
|
let i = llvm_usize.const_int(idx as u64, false);
|
||||||
|
|
||||||
|
let dim_idx = ctx.builder
|
||||||
|
.build_extract_value(indices, idx, "")
|
||||||
|
.map(|v| v.into_int_value())
|
||||||
|
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
|
||||||
|
.unwrap();
|
||||||
|
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
|
||||||
|
|
||||||
|
let dim_lt = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim_idx,
|
||||||
|
dim_sz,
|
||||||
|
""
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
dim_lt,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} is out of bounds for axis 0 with size {1}",
|
||||||
|
[Some(dim_idx), Some(dim_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the data at the index specified by `indices`.
|
/// Returns the pointer to the data at the index specified by `indices`.
|
||||||
pub fn ptr_offset(
|
pub fn ptr_offset(
|
||||||
&self,
|
&self,
|
||||||
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub unsafe fn get_unsafe_const(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
indices: ArrayValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
|
||||||
|
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
pub unsafe fn get_unsafe(
|
pub unsafe fn get_unsafe(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the data at the index specified by `indices`.
|
||||||
|
pub fn get_const(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
indices: ArrayValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
|
||||||
|
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the data at the index specified by `indices`.
|
/// Returns the data at the index specified by `indices`.
|
||||||
pub fn get(
|
pub fn get(
|
||||||
&self,
|
&self,
|
||||||
|
@ -10,8 +10,8 @@ use inkwell::{
|
|||||||
context::Context,
|
context::Context,
|
||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::Module,
|
module::Module,
|
||||||
types::BasicTypeEnum,
|
types::{BasicTypeEnum, IntType},
|
||||||
values::{FloatValue, IntValue, PointerValue},
|
values::{ArrayValue, FloatValue, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
@ -707,16 +707,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||||||
Ok(indices)
|
Ok(indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_flatten_index`.
|
fn call_ndarray_flatten_index_impl<'ctx>(
|
||||||
///
|
|
||||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
|
||||||
/// `NDArray`.
|
|
||||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
|
||||||
pub fn call_ndarray_flatten_index<'ctx>(
|
|
||||||
generator: &dyn CodeGenerator,
|
generator: &dyn CodeGenerator,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
indices: ListValue<'ctx>,
|
indices: PointerValue<'ctx>,
|
||||||
|
indices_size: IntValue<'ctx>,
|
||||||
) -> Result<IntValue<'ctx>, String> {
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
@ -724,6 +720,19 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
debug_assert_eq!(
|
||||||
|
IntType::try_from(indices.get_type().get_element_type())
|
||||||
|
.map(|itype| itype.get_bit_width())
|
||||||
|
.unwrap_or_default(),
|
||||||
|
llvm_i32.get_bit_width(),
|
||||||
|
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
debug_assert_eq!(
|
||||||
|
indices_size.get_type().get_bit_width(),
|
||||||
|
llvm_usize.get_bit_width(),
|
||||||
|
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||||
|
);
|
||||||
|
|
||||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||||
32 => "__nac3_ndarray_flatten_index",
|
32 => "__nac3_ndarray_flatten_index",
|
||||||
64 => "__nac3_ndarray_flatten_index64",
|
64 => "__nac3_ndarray_flatten_index64",
|
||||||
@ -745,8 +754,6 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
let ndarray_dims = ndarray.get_dims();
|
let ndarray_dims = ndarray.get_dims();
|
||||||
let indices_size = indices.load_size(ctx, None);
|
|
||||||
let indices_data = indices.get_data();
|
|
||||||
|
|
||||||
let index = ctx.builder
|
let index = ctx.builder
|
||||||
.build_call(
|
.build_call(
|
||||||
@ -766,3 +773,70 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||||||
|
|
||||||
Ok(index)
|
Ok(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
|
/// multidimensional index.
|
||||||
|
///
|
||||||
|
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
|
pub fn call_ndarray_flatten_index<'ctx>(
|
||||||
|
generator: &dyn CodeGenerator,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: ListValue<'ctx>,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
let indices_size = indices.load_size(ctx, None);
|
||||||
|
let indices_data = indices.get_data();
|
||||||
|
|
||||||
|
call_ndarray_flatten_index_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
indices_data.get_ptr(ctx),
|
||||||
|
indices_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||||
|
/// multidimensional index.
|
||||||
|
///
|
||||||
|
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||||
|
pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: ArrayValue<'ctx>,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let indices_size = indices.get_type().len();
|
||||||
|
let indices_alloca = generator.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
indices.get_type().get_element_type(),
|
||||||
|
llvm_usize.const_int(indices_size as u64, false),
|
||||||
|
None
|
||||||
|
)?;
|
||||||
|
for i in 0..indices_size {
|
||||||
|
let v = ctx.builder.build_extract_value(indices, i, "")
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let elem_ptr = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
indices_alloca,
|
||||||
|
&[ctx.ctx.i32_type().const_int(i as u64, false)],
|
||||||
|
""
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(elem_ptr, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
call_ndarray_flatten_index_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
indices_alloca,
|
||||||
|
llvm_usize.const_int(indices_size as u64, false),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user