From 847091580956de165ad09ec56d8b4a41af25850e Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 22 Jan 2024 16:51:35 +0800 Subject: [PATCH] core: Add NDArrayValue and helper functions --- nac3core/src/codegen/classes.rs | 489 ++++++++++++++++++++++++++++++- nac3core/src/codegen/irrt/irrt.c | 38 ++- nac3core/src/codegen/irrt/mod.rs | 100 +++++-- nac3core/src/codegen/mod.rs | 27 -- nac3core/src/toplevel/numpy.rs | 219 +++----------- 5 files changed, 630 insertions(+), 243 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 4f8e78512..86822eefe 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -3,7 +3,12 @@ use inkwell::{ types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, values::{BasicValueEnum, IntValue, PointerValue}, }; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::{ + CodeGenContext, + CodeGenerator, + irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, + stmt::gen_for_callback, +}; #[cfg(not(debug_assertions))] pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} @@ -380,3 +385,485 @@ impl<'ctx> RangeValue<'ctx> { ctx.builder.build_load(pstep, var_name.as_str()).into_int_value() } } + +#[cfg(not(debug_assertions))] +pub fn assert_is_ndarray<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} + +#[cfg(debug_assertions)] +pub fn assert_is_ndarray<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) { + if let Err(msg) = NDArrayValue::is_instance(value, llvm_usize) { + panic!("{msg}") + } +} + +/// Proxy type for accessing an `NDArray` value in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> NDArrayValue<'ctx> { + /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")) + }; + if llvm_ndarray_ty.count_fields() != 3 { + return Err(format!("Expected 3 fields in `NDArray`, got {}", llvm_ndarray_ty.count_fields())) + } + + let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); + let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { + return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")) + }; + if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected {}-bit int type for `ndarray.0`, got {}-bit int", + llvm_usize.get_bit_width(), + ndarray_ndims_ty.get_bit_width())) + } + + let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); + let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { + return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")) + }; + let ndarray_dims = ndarray_pdims.get_element_type(); + let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { + return Err(format!("Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}")) + }; + if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + llvm_usize.get_bit_width(), + ndarray_dims.get_bit_width())) + } + + let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); + let Ok(_) = PointerType::try_from(ndarray_data_ty) else { + return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")) + }; + + Ok(()) + } + + /// Creates an [NDArrayValue] from a [PointerValue]. + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + assert_is_ndarray(ptr, llvm_usize); + NDArrayValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `NDArray` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. + fn get_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + } + } + + /// Stores the number of dimensions `ndims` into this instance. + pub fn store_ndims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.get_ndims(ctx); + ctx.builder.build_store(pndims, ndims); + } + + /// Returns the number of dimensions of this `NDArray` as a value. + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.get_ndims(ctx); + ctx.builder.build_load(pndims, "").into_int_value() + } + + /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` + /// on the field. + fn get_dims_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of dimension sizes `dims` into this instance. + fn store_dims(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_dims_ptr(ctx), dims); + } + + /// Convenience method for creating a new array storing dimension sizes with the given `size`. + pub fn create_dims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + llvm_usize: IntType<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_dims(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "")); + } + + /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. + pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> { + NDArrayDimsProxy(self.clone()) + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + fn get_data_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_data_ptr(ctx), data); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and + /// `size`. + pub fn create_data( + &self, + ctx: &CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "")); + } + + /// Returns a proxy object to the field storing the data of this `NDArray`. + pub fn get_data(&self) -> NDArrayDataProxy<'ctx> { + NDArrayDataProxy(self.clone()) + } +} + +impl<'ctx> Into> for NDArrayValue<'ctx> { + fn into(self) -> PointerValue<'ctx> { + self.get_ptr() + } +} + +/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDimsProxy<'ctx>(NDArrayValue<'ctx>); + +impl<'ctx> NDArrayDimsProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.dims")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_dims_ptr(ctx), var_name.as_str()).into_pointer_value() + } + + /// Returns the pointer to the size of the `idx`-th dimension. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + self.0.load_ndims(ctx), + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + let var_name = name + .map(|v| format!("{v}.addr")) + .unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + var_name.as_str(), + ) + } + } + + /// Returns the size of the `idx`-th dimension. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> IntValue<'ctx> { + let ptr = self.ptr_offset(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()).into_int_value() + } +} + +/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDataProxy<'ctx>(NDArrayValue<'ctx>); + +impl<'ctx> NDArrayDataProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_data_ptr(ctx), var_name.as_str()).into_pointer_value() + } + + pub unsafe fn ptr_to_data_flattened_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + name.unwrap_or_default(), + ) + } + + /// Returns the pointer to the data at the `idx`-th flattened index. + pub fn ptr_to_data_flattened( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let ndims = self.0.load_ndims(ctx); + let dims = self.0.get_dims().get_ptr(ctx); + let data_sz = call_ndarray_calc_size(generator, ctx, ndims, dims); + + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + data_sz, + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds with size {1}", + [Some(idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { + self.ptr_to_data_flattened_unchecked(ctx, idx, name) + } + } + + pub unsafe fn get_flattened_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_to_data_flattened_unchecked(ctx, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the `idx`-th flattened index. + pub fn get_flattened( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_to_data_flattened(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + pub unsafe fn ptr_offset_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let indices_elem_ty = indices.get_data().get_ptr(ctx).get_type().get_element_type(); + let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { + panic!("Expected list[int32] but got {indices_elem_ty}") + }; + assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}"); + + let index = call_ndarray_flatten_index( + generator, + ctx, + self.0, + indices, + ).unwrap(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[index], + name.unwrap_or_default(), + ) + } + } + + /// Returns the pointer to the data at the index specified by `indices`. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let nidx_leq_ndims = ctx.builder.build_int_compare( + IntPredicate::SLE, + indices.load_size(ctx, None), + self.0.load_ndims(ctx), + "" + ); + ctx.make_assert( + generator, + nidx_leq_ndims, + "0:IndexError", + "invalid index to scalar variable", + [None, None, None], + ctx.current_loc, + ); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let indices_len = indices.load_size(ctx, None); + let ndarray_len = self.0.load_ndims(ctx); + + let min_fn_name = format!("llvm.umin.i{}", llvm_usize.get_bit_width()); + let min_fn = ctx.module.get_function(min_fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_usize.into(), llvm_usize.into()], + false + ); + ctx.module.add_function(min_fn_name.as_str(), fn_type, None) + }); + + let len = ctx + .builder + .build_call(min_fn, &[indices_len.into(), ndarray_len.into()], "") + .try_as_basic_value() + .map_left(|v| v.into_int_value()) + .left() + .unwrap(); + + let i = ctx.builder.build_load(i_addr, "").into_int_value(); + Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "")) + }, + |generator, ctx, i_addr| { + let i = ctx.builder.build_load(i_addr, "").into_int_value(); + let (dim_idx, dim_sz) = unsafe { + ( + indices.get_data().get_unchecked(ctx, i, None).into_int_value(), + self.0.get_dims().get(ctx, generator, i, None), + ) + }; + + let dim_lt = ctx.builder.build_int_compare( + IntPredicate::SLT, + dim_idx, + dim_sz, + "" + ); + + ctx.make_assert( + generator, + dim_lt, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(dim_idx), Some(dim_sz), None], + ctx.current_loc, + ); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + ).unwrap(); + + unsafe { + self.ptr_offset_unchecked(ctx, generator, indices, name) + } + } + + pub unsafe fn get_unsafe( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_unchecked(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the index specified by `indices`. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } +} diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 8b28bc1ad..97969d1ef 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -268,4 +268,40 @@ void __nac3_ndarray_calc_nd_indices64( idxs[i] = (index / stride) % dims[i]; stride *= dims[i]; } -} \ No newline at end of file +} + +uint32_t __nac3_ndarray_flatten_index( + const uint32_t* dims, + uint32_t num_dims, + const uint32_t* indices, + uint32_t num_indices +) { + uint32_t idx = 0; + uint32_t stride = 1; + for (uint32_t i = num_dims - 1; i-- >= 0; ) { + if (i < num_indices) { + idx += (stride * indices[i]); + } + + stride *= dims[i]; + } + return idx; +} + +uint64_t __nac3_ndarray_flatten_index64( + const uint64_t* dims, + uint64_t num_dims, + const uint32_t* indices, + uint64_t num_indices +) { + uint64_t idx = 0; + uint64_t stride = 1; + for (uint64_t i = num_dims - 1; i-- >= 0; ) { + if (i < num_indices) { + idx += (stride * indices[i]); + } + + stride *= dims[i]; + } + return idx; +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 3b21dd9c4..5702ba93e 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,8 +1,7 @@ use crate::typecheck::typedef::Type; use super::{ - classes::ListValue, - assert_is_ndarray, + classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, }; @@ -607,7 +606,7 @@ pub fn call_ndarray_calc_size<'ctx>( pub fn call_ndarray_init_dims<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, shape: ListValue<'ctx>, ) { let llvm_void = ctx.ctx.void_type(); @@ -617,8 +616,6 @@ pub fn call_ndarray_init_dims<'ctx>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_is_ndarray(ndarray); - let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_init_dims", 64 => "__nac3_ndarray_init_dims64", @@ -637,22 +634,14 @@ pub fn call_ndarray_init_dims<'ctx>( ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None) }); - let ndarray_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ); + let ndarray_dims = ndarray.get_dims(); let shape_data = shape.get_data(); - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); + let ndarray_num_dims = ndarray.load_ndims(ctx); ctx.builder.build_call( ndarray_init_dims_fn, &[ - ndarray_dims.into(), + ndarray_dims.get_ptr(ctx).into(), shape_data.get_ptr(ctx).into(), ndarray_num_dims.into(), ], @@ -669,12 +658,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, ) -> Result, String> { - assert_is_ndarray(ndarray); - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); @@ -698,16 +684,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) }); - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - let ndarray_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_pointer_value(); + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.get_dims(); let indices = ctx.builder.build_array_alloca( llvm_usize, @@ -719,7 +697,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ndarray_calc_nd_indices_fn, &[ index.into(), - ndarray_dims.into(), + ndarray_dims.get_ptr(ctx).into(), ndarray_num_dims.into(), indices.into(), ], @@ -727,4 +705,64 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ); Ok(indices) +} + +/// Generates a call to `__nac3_ndarray_flatten_index`. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx>( + generator: &dyn CodeGenerator, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: ListValue<'ctx>, +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_flatten_index", + 64 => "__nac3_ndarray_flatten_index64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_usize.into(), + llvm_pusize.into(), + llvm_pi32.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.get_dims(); + let indices_size = indices.load_size(ctx, None); + let indices_data = indices.get_data(); + + let index = ctx.builder + .build_call( + ndarray_flatten_index_fn, + &[ + ndarray_num_dims.into(), + ndarray_dims.get_ptr(ctx).into(), + indices_size.into(), + indices_data.get_ptr(ctx).into(), + ], + "", + ) + .try_as_basic_value() + .map_left(|v| v.into_int_value()) + .left() + .unwrap(); + + Ok(index) } \ No newline at end of file diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2102076d7..66c1d4f4f 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -34,9 +34,6 @@ use std::sync::{ }; use std::thread; -#[cfg(debug_assertions)] -use inkwell::types::AnyTypeEnum; - pub mod classes; pub mod concrete_type; pub mod expr; @@ -999,27 +996,3 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } - -/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. -fn assert_is_ndarray(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_ndarray_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") - }; - - assert_eq!(llvm_ndarray_ty.count_fields(), 3); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); - let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { - unreachable!() - }; - let BasicTypeEnum::PointerType(dims) = ndarray_dims else { - panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") - }; - assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); - } - - value -} diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 4b6f49680..820b575bc 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -3,7 +3,7 @@ use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, irrt::{ @@ -27,11 +27,10 @@ fn create_ndarray_const_shape<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ArrayValue<'ctx> -) -> Result, String> { +) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); @@ -68,54 +67,18 @@ fn create_ndarray_const_shape<'ctx, 'a>( llvm_ndarray_t.into(), None, )?; + let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false); + ndarray.store_ndims(ctx, generator, num_dims); - let ndarray_num_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "", - ) - }; - ctx.builder.build_store(ndarray_num_dims, num_dims); - - let ndarray_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "", - ) - }; - - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - - ctx.builder.build_store( - ndarray_dims, - ctx.builder.build_array_alloca( - llvm_usize, - ndarray_num_dims, - "", - ), - ); + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); for i in 0..shape.get_type().len() { - let ndarray_dim = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_pointer_value(); - let ndarray_dim = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray_dim, - &[llvm_i32.const_int(i as u64, true)], - "", - ) - }; + let ndarray_dim = ndarray + .get_dims() + .ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None); let shape_dim = ctx.builder.build_extract_value(shape, i, "") .map(|val| val.into_int_value()) .unwrap(); @@ -123,42 +86,14 @@ fn create_ndarray_const_shape<'ctx, 'a>( ctx.builder.build_store(ndarray_dim, shape_dim); } - let (ndarray_num_dims, ndarray_dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; + let ndarray_dims = ndarray.get_dims().get_ptr(ctx); let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), - ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), - ); - - let ndarray_data = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - "", - ) - }; - ctx.builder.build_store( - ndarray_data, - ctx.builder.build_array_alloca( - llvm_ndarray_data_t, - ndarray_num_elems, - "" - ), + ndarray.load_ndims(ctx), + ndarray_dims, ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); Ok(ndarray) } @@ -214,7 +149,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); @@ -284,79 +219,23 @@ fn call_ndarray_empty_impl<'ctx, 'a>( llvm_ndarray_t.into(), None, )?; + let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); let num_dims = shape.load_size(ctx, None); + ndarray.store_ndims(ctx, generator, num_dims); - let ndarray_num_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "", - ) - }; - ctx.builder.build_store(ndarray_num_dims, num_dims); - - let ndarray_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "", - ) - }; - - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - - ctx.builder.build_store( - ndarray_dims, - ctx.builder.build_array_alloca( - llvm_usize, - ndarray_num_dims, - "", - ), - ); + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); call_ndarray_init_dims(generator, ctx, ndarray, shape); - let (ndarray_num_dims, ndarray_dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), - ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), - ); - - let ndarray_data = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - "", - ) - }; - ctx.builder.build_store( - ndarray_data, - ctx.builder.build_array_alloca( - llvm_ndarray_data_t, - ndarray_num_elems, - "", - ), + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); Ok(ndarray) } @@ -369,35 +248,19 @@ fn call_ndarray_empty_impl<'ctx, 'a>( fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> where ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result, String>, { - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (num_dims, dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; - let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(num_dims, "").into_int_value(), - ctx.builder.build_load(dims, "").into_pointer_value(), + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), ); gen_for_callback( @@ -417,21 +280,11 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "")) }, |generator, ctx, i_addr| { - let ndarray_data = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - None - ).into_pointer_value(); - let i = ctx.builder .build_load(i_addr, "") .into_int_value(); let elem = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray_data, - &[i], - "" - ) + ndarray.get_data().ptr_to_data_flattened_unchecked(ctx, i, None) }; let value = value_fn(generator, ctx, i)?; @@ -459,7 +312,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( fn ndarray_fill_indexed<'ctx, 'a, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> where @@ -491,7 +344,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let supported_types = [ ctx.primitives.int32, ctx.primitives.int64, @@ -527,7 +380,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let supported_types = [ ctx.primitives.int32, ctx.primitives.int64, @@ -564,7 +417,7 @@ fn call_ndarray_full_impl<'ctx, 'a>( elem_ty: Type, shape: ListValue<'ctx>, fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { +) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; ndarray_fill_flattened( generator, @@ -633,7 +486,7 @@ fn call_ndarray_eye_impl<'ctx, 'a>( nrows: IntValue<'ctx>, ncols: IntValue<'ctx>, offset: IntValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize_2 = llvm_usize.array_type(2); @@ -718,7 +571,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.zeros`. @@ -742,7 +595,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.ones`. @@ -766,7 +619,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.full`. @@ -794,7 +647,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( fill_value_ty, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), fill_value_arg, - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.eye`. @@ -839,7 +692,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>( nrows_arg.into_int_value(), ncols_arg.into_int_value(), offset_arg.into_int_value(), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.identity`. @@ -866,5 +719,5 @@ pub fn gen_ndarray_identity<'ctx, 'a>( n_arg.into_int_value(), n_arg.into_int_value(), llvm_usize.const_zero(), - ) + ).map(NDArrayValue::into) } \ No newline at end of file