From 5ee08b585f07e267629bf9020802ed2d6f85ca1f Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 23 Jan 2024 17:21:24 +0800 Subject: [PATCH 1/3] core: Add ListValue and helper functions --- nac3core/src/codegen/classes.rs | 225 +++++++++++++++++++++++++++++++ nac3core/src/codegen/expr.rs | 89 +++++------- nac3core/src/codegen/irrt/mod.rs | 45 +++---- nac3core/src/codegen/mod.rs | 17 +-- nac3core/src/codegen/stmt.rs | 27 ++-- nac3core/src/toplevel/numpy.rs | 47 +++---- 6 files changed, 308 insertions(+), 142 deletions(-) create mode 100644 nac3core/src/codegen/classes.rs diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs new file mode 100644 index 0000000..66f74ac --- /dev/null +++ b/nac3core/src/codegen/classes.rs @@ -0,0 +1,225 @@ +use inkwell::{ + IntPredicate, + types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, + values::{BasicValueEnum, IntValue, PointerValue}, +}; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +#[cfg(not(debug_assertions))] +pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} + +#[cfg(debug_assertions)] +pub fn assert_is_list<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) { + if let Err(msg) = ListValue::is_instance(value, llvm_usize) { + panic!("{msg}") + } +} + +/// Proxy type for accessing a `list` value in LLVM. +#[derive(Copy, Clone)] +pub struct ListValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> ListValue<'ctx> { + /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let llvm_list_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { + panic!("Expected struct type for `list` type, got {llvm_list_ty}") + }; + if llvm_list_ty.count_fields() != 2 { + return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields())) + } + + let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); + let Ok(_) = PointerType::try_from(list_size_ty) else { + return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")) + }; + + let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); + let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { + return Err(format!("Expected int type for `list.1`, got {list_data_ty}")) + }; + if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected {}-bit int type for `list.1`, got {}-bit int", + llvm_usize.get_bit_width(), + list_data_ty.get_bit_width())) + } + + Ok(()) + } + + /// Creates an [ListValue] from a [PointerValue]. + pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { + assert_is_list(ptr, llvm_usize); + ListValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `list` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + fn get_data_pptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + } + } + + /// Returns the pointer to the field storing the size of this `list`. + fn get_size_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.size.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_data_pptr(ctx), data); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and `size`. + /// + /// If `size` is [None], the size stored in the field of this instance is used instead. + pub fn create_data( + &self, + ctx: &CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: Option>, + ) { + let size = size.unwrap_or_else(|| self.load_size(ctx, None)); + self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "")); + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + pub fn get_data(&self) -> ListDataProxy<'ctx> { + ListDataProxy(self.clone()) + } + + /// Stores the `size` of this `list` into this instance. + pub fn store_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + size: IntValue<'ctx>, + ) { + debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + + let psize = self.get_size_ptr(ctx); + ctx.builder.build_store(psize, size); + } + + /// Returns the size of this `list` as a value. + pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let psize = self.get_size_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.size"))) + .unwrap_or_default(); + + ctx.builder.build_load(psize, var_name.as_str()).into_int_value() + } +} + +/// Proxy type for accessing the `data` array of an `list` instance in LLVM. +#[derive(Copy, Clone)] +pub struct ListDataProxy<'ctx>(ListValue<'ctx>); + +impl<'ctx> ListDataProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_data_pptr(ctx), var_name.as_str()).into_pointer_value() + } + + pub unsafe fn ptr_offset_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name + .map(|v| format!("{v}.addr")) + .unwrap_or_default(); + + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + var_name.as_str(), + ) + } + + /// Returns the pointer to the data at the `idx`-th index. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + self.0.load_size(ctx, None), + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "list index out of range", + [None, None, None], + ctx.current_loc, + ); + + unsafe { + self.ptr_offset_unchecked(ctx, idx, name) + } + } + + pub unsafe fn get_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_unchecked(ctx, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the `idx`-th flattened index. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 79ebe4a..734e945 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ + classes::ListValue, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_type, @@ -896,43 +897,26 @@ pub fn allocate_list<'ctx, G: CodeGenerator>( ty: BasicTypeEnum<'ctx>, length: IntValue<'ctx>, name: Option<&str>, -) -> PointerValue<'ctx> { +) -> ListValue<'ctx> { let size_t = generator.get_size_type(ctx.ctx); - let i32_t = ctx.ctx.i32_type(); // List structure; type { ty*, size_t } let arr_ty = ctx.ctx .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); - let zero = ctx.ctx.i32_type().const_zero(); let arr_str_ptr = ctx.builder.build_alloca( arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str() ); + let list = ListValue::from_ptr_val(arr_str_ptr, size_t, Some("list")); - unsafe { - // Pointer to the `length` element of the list structure - let len_ptr = ctx.builder.build_in_bounds_gep( - arr_str_ptr, - &[zero, i32_t.const_int(1, false)], - "" - ); - let length = ctx.builder.build_int_z_extend( - length, - size_t, - "" - ); - ctx.builder.build_store(len_ptr, length); + let length = ctx.builder.build_int_z_extend( + length, + size_t, + "" + ); + list.store_size(ctx, generator, length); + list.create_data(ctx, ty, None); - // Pointer to the `data` element of the list structure - let arr_ptr = ctx.builder.build_array_alloca(ty, length, ""); - let ptr_to_arr = ctx.builder.build_in_bounds_gep( - arr_str_ptr, - &[zero, i32_t.const_zero()], - "" - ); - ctx.builder.build_store(ptr_to_arr, arr_ptr); - } - - arr_str_ptr + list } /// Generates LLVM IR for a [list comprehension expression][expr]. @@ -1006,8 +990,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( list_alloc_size.into_int_value(), Some("listcomp.addr") ); - list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr")) - .into_pointer_value(); + list_content = list.get_data().get_ptr(ctx); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); @@ -1042,8 +1025,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ) .into_int_value(); list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp")); - list_content = - ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value(); + list_content = list.get_data().get_ptr(ctx); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); @@ -1065,12 +1047,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } // Emits the content of `cont_bb` - let emit_cont_bb = |ctx: &CodeGenContext| { + let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { ctx.builder.position_at_end(cont_bb); - let len_ptr = unsafe { - ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length") - }; - ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); + list.store_size(ctx, generator, ctx.builder.build_load(index, "index").into_int_value()); }; for cond in ifs { @@ -1079,7 +1058,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } else { // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // no element matches the predicate - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); return Ok(None) }; @@ -1092,7 +1071,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let Some(elem) = generator.gen_expr(ctx, elt)? else { // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); return Ok(None) }; @@ -1104,9 +1083,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( .build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc")); ctx.builder.build_unconditional_branch(test_bb); - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); - Ok(Some(list.into())) + Ok(Some(list.get_ptr().into())) } /// Generates LLVM IR for a [binary operator expression][expr]. @@ -1226,6 +1205,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ) -> Result>, String> { ctx.current_loc = expr.location; let int32 = ctx.ctx.i32_type(); + let usize = generator.get_size_type(ctx.ctx); let zero = int32.const_int(0, false); let loc = ctx.debug_info.0.create_debug_location( @@ -1296,19 +1276,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); - let arr_ptr = ctx.build_gep_and_load(arr_str_ptr, &[zero, zero], Some("list.ptr.addr")) - .into_pointer_value(); - unsafe { - for (i, v) in elements.iter().enumerate() { - let elem_ptr = ctx.builder.build_gep( - arr_ptr, - &[int32.const_int(i as u64, false)], - "elem_ptr", - ); - ctx.builder.build_store(elem_ptr, *v); - } + let arr_ptr = arr_str_ptr.get_data(); + for (i, v) in elements.iter().enumerate() { + let elem_ptr = arr_ptr + .ptr_offset(ctx, generator, usize.const_int(i as u64, false), Some("elem_ptr")); + ctx.builder.build_store(elem_ptr, *v); } - arr_str_ptr.into() + arr_str_ptr.get_ptr().into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -1758,9 +1732,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None) }; + let v = ListValue::from_ptr_val(v, usize, Some("arr")); let ty = ctx.get_llvm_type(generator, *ty); - let arr_ptr = ctx.build_gep_and_load(v, &[zero, zero], Some("arr.addr")) - .into_pointer_value(); if let ExprKind::Slice { lower, upper, step } = &slice.node { let one = int32.const_int(1, false); let Some((start, end, step)) = @@ -1800,11 +1773,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.into() + res_array_ret.get_ptr().into() } else { - let len = ctx - .build_gep_and_load(v, &[zero, int32.const_int(1, false)], Some("len")) - .into_int_value(); + let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() } else { @@ -1843,7 +1814,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( [Some(raw_index), Some(len), None], expr.location, ); - ctx.build_gep_and_load(arr_ptr, &[index], None).into() + v.get_data().get(ctx, generator, index, None).into() } } TypeEnum::TNDArray { .. } => { diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 530ce0a..3b21dd9 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,11 @@ use crate::typecheck::typedef::Type; -use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator}; +use super::{ + classes::ListValue, + assert_is_ndarray, + CodeGenContext, + CodeGenerator, +}; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -158,12 +163,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( step: &Option>>>, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, - list: PointerValue<'ctx>, + list: ListValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let one = int32.const_int(1, false); - let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value(); + let length = list.load_size(ctx, Some("length")); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32"); Ok(Some(match (start, end, step) { (s, e, None) => ( @@ -295,9 +300,9 @@ pub fn list_slice_assignment<'ctx>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, ty: BasicTypeEnum<'ctx>, - dest_arr: PointerValue<'ctx>, + dest_arr: ListValue<'ctx>, dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), - src_arr: PointerValue<'ctx>, + src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { let size_ty = generator.get_size_type(ctx.ctx); @@ -326,21 +331,21 @@ pub fn list_slice_assignment<'ctx>( let zero = int32.const_zero(); let one = int32.const_int(1, false); - let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr")); + let dest_arr_ptr = dest_arr.get_data().get_ptr(ctx); let dest_arr_ptr = ctx.builder.build_pointer_cast( - dest_arr_ptr.into_pointer_value(), + dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast", ); - let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value(); + let dest_len = dest_arr.load_size(ctx, Some("dest.len")); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32"); - let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr")); + let src_arr_ptr = src_arr.get_data().get_ptr(ctx); let src_arr_ptr = ctx.builder.build_pointer_cast( - src_arr_ptr.into_pointer_value(), + src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast", ); - let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value(); + let src_len = src_arr.load_size(ctx, Some("src.len")); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32"); // index in bound and positive should be done @@ -443,9 +448,8 @@ pub fn list_slice_assignment<'ctx>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb); ctx.builder.position_at_end(update_bb); - let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") }; let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len"); - ctx.builder.build_store(dest_len_ptr, new_len); + dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(cont_bb); } @@ -604,11 +608,8 @@ pub fn call_ndarray_init_dims<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, ndarray: PointerValue<'ctx>, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) { - assert_is_ndarray(ndarray); - assert_is_list(shape); - let llvm_void = ctx.ctx.void_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -616,6 +617,8 @@ pub fn call_ndarray_init_dims<'ctx>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_is_ndarray(ndarray); + let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_init_dims", 64 => "__nac3_ndarray_init_dims64", @@ -639,11 +642,7 @@ pub fn call_ndarray_init_dims<'ctx>( &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], None, ); - let shape_data = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None - ); + let shape_data = shape.get_data(); let ndarray_num_dims = ctx.build_gep_and_load( ndarray, &[llvm_i32.const_zero(), llvm_i32.const_zero()], @@ -654,7 +653,7 @@ pub fn call_ndarray_init_dims<'ctx>( ndarray_init_dims_fn, &[ ndarray_dims.into(), - shape_data.into(), + shape_data.get_ptr(ctx).into(), ndarray_num_dims.into(), ], "", diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index b1836a0..2102076 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -37,6 +37,7 @@ use std::thread; #[cfg(debug_assertions)] use inkwell::types::AnyTypeEnum; +pub mod classes; pub mod concrete_type; pub mod expr; mod generator; @@ -999,22 +1000,6 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } -/// Checks whether the pointer `value` refers to a `list` in LLVM. -fn assert_is_list(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_shape_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { - panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") - }; - assert_eq!(llvm_shape_ty.count_fields(), 2); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); - } - - value -} - /// Checks whether the pointer `value` refers to an `NDArray` in LLVM. fn assert_is_ndarray(value: PointerValue) -> PointerValue { #[cfg(debug_assertions)] diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1cf57b2..8ac5a7d 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::{ codegen::{ + classes::ListValue, expr::gen_binop_expr, gen_in_range_check, }, @@ -92,6 +93,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( pattern: &Expr>, name: Option<&str>, ) -> Result>, String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples Ok(Some(match &pattern.node { @@ -132,16 +135,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( ExprKind::Subscript { value, slice, .. } => { match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { TypeEnum::TList { .. } => { - let i32_type = ctx.ctx.i32_type(); - let zero = i32_type.const_zero(); let v = generator .gen_expr(ctx, value)? .unwrap() .to_basic_value_enum(ctx, generator, value.custom.unwrap())? .into_pointer_value(); - let len = ctx - .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) - .into_int_value(); + let v = ListValue::from_ptr_val(v, llvm_usize, None); + let len = v.load_size(ctx, Some("len")); let raw_index = generator .gen_expr(ctx, slice)? .unwrap() @@ -180,12 +180,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( [Some(raw_index), Some(len), None], slice.location, ); - unsafe { - let arr_ptr = ctx - .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) - .into_pointer_value(); - ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) - } + v.get_data().ptr_offset(ctx, generator, index, name) } TypeEnum::TNDArray { .. } => { @@ -206,6 +201,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( target: &Expr>, value: ValueEnum<'ctx>, ) -> Result<(), String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + match &target.node { ExprKind::Tuple { elts, .. } => { let BasicValueEnum::StructValue(v) = @@ -233,6 +230,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? .into_pointer_value(); + let ls = ListValue::from_ptr_val(ls, llvm_usize, None); let Some((start, end, step)) = handle_slice_indices(lower, upper, step, ctx, generator, ls)? else { return Ok(()) @@ -240,9 +238,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = value .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); - let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { - unreachable!() - }; + let value = ListValue::from_ptr_val(value, llvm_usize, None); + let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { + unreachable!() + }; let ty = ctx.get_llvm_type(generator, *ty); let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 13bb8a5..4b6f496 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -3,6 +3,7 @@ use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ + classes::ListValue, CodeGenContext, CodeGenerator, irrt::{ @@ -212,7 +213,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); @@ -239,29 +240,15 @@ fn call_ndarray_empty_impl<'ctx, 'a>( let i = ctx.builder .build_load(i_addr, "") .into_int_value(); - let shape_len = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_int_value(); + let shape_len = shape.load_size(ctx, None); - Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, "")) + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "")) }, |generator, ctx, i_addr| { - let shape_elems = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None - ).into_pointer_value(); - let i = ctx.builder .build_load(i_addr, "") .into_int_value(); - let shape_dim = ctx.build_gep_and_load( - shape_elems, - &[i], - None - ).into_int_value(); + let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value(); let shape_dim_gez = ctx.builder.build_int_compare( IntPredicate::SGE, @@ -298,11 +285,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( None, )?; - let num_dims = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None - ).into_int_value(); + let num_dims = shape.load_size(ctx, None); let ndarray_num_dims = unsafe { ctx.builder.build_in_bounds_gep( @@ -507,7 +490,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -543,7 +526,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -579,7 +562,7 @@ fn call_ndarray_full_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, fill_value: BasicValueEnum<'ctx>, ) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; @@ -725,6 +708,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -733,7 +717,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -748,6 +732,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -756,7 +741,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -771,6 +756,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -779,7 +765,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -794,6 +780,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 2); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -805,7 +792,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( generator, context, fill_value_ty, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), fill_value_arg, ) } -- 2.44.1 From 148900302e92d2b9dc2aedaefda98415d80afae9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 23 Jan 2024 18:27:00 +0800 Subject: [PATCH 2/3] core: Add RangeValue and helper functions --- nac3core/src/codegen/classes.rs | 159 +++++++++++++++++++++++++++++- nac3core/src/codegen/expr.rs | 19 ++-- nac3core/src/codegen/stmt.rs | 4 +- nac3core/src/toplevel/builtins.rs | 3 +- 4 files changed, 168 insertions(+), 17 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 66f74ac..4f8e785 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -28,7 +28,7 @@ impl<'ctx> ListValue<'ctx> { ) -> Result<(), String> { let llvm_list_ty = value.get_type().get_element_type(); let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - panic!("Expected struct type for `list` type, got {llvm_list_ty}") + return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")) }; if llvm_list_ty.count_fields() != 2 { return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields())) @@ -223,3 +223,160 @@ impl<'ctx> ListDataProxy<'ctx> { ctx.builder.build_load(ptr, name.unwrap_or_default()) } } + +#[cfg(not(debug_assertions))] +pub fn assert_is_range(_value: PointerValue) {} + +#[cfg(debug_assertions)] +pub fn assert_is_range(value: PointerValue) { + if let Err(msg) = RangeValue::is_instance(value) { + panic!("{msg}") + } +} + +/// Proxy type for accessing a `range` value in LLVM. +#[derive(Copy, Clone)] +pub struct RangeValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> RangeValue<'ctx> { + /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. + pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> { + let llvm_range_ty = value.get_type().get_element_type(); + let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { + return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")) + }; + if llvm_range_ty.len() != 3 { + return Err(format!("Expected 3 elements for `range` type, got {}", llvm_range_ty.len())) + } + + let llvm_range_elem_ty = llvm_range_ty.get_element_type(); + let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else { + return Err(format!("Expected int type for `range` element type, got {llvm_range_elem_ty}")) + }; + if llvm_range_elem_ty.get_bit_width() != 32 { + return Err(format!("Expected 32-bit int type for `range` element type, got {}", + llvm_range_elem_ty.get_bit_width())) + } + + Ok(()) + } + + /// Creates an [RangeValue] from a [PointerValue]. + pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { + assert_is_range(ptr); + RangeValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `range` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + fn get_start_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.start.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], + var_name.as_str(), + ) + } + } + + fn get_end_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.end.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + var_name.as_str(), + ) + } + } + + fn get_step_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.step.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], + var_name.as_str(), + ) + } + } + + /// Stores the `start` value into this instance. + pub fn store_start( + &self, + ctx: &CodeGenContext<'ctx, '_>, + start: IntValue<'ctx>, + ) { + debug_assert_eq!(start.get_type().get_bit_width(), 32); + + let pstart = self.get_start_ptr(ctx); + ctx.builder.build_store(pstart, start); + } + + /// Returns the `start` value of this `range`. + pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstart = self.get_start_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.start"))) + .unwrap_or_default(); + + ctx.builder.build_load(pstart, var_name.as_str()).into_int_value() + } + + /// Stores the `end` value into this instance. + pub fn store_end( + &self, + ctx: &CodeGenContext<'ctx, '_>, + end: IntValue<'ctx>, + ) { + debug_assert_eq!(end.get_type().get_bit_width(), 32); + + let pend = self.get_start_ptr(ctx); + ctx.builder.build_store(pend, end); + } + + /// Returns the `end` value of this `range`. + pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pend = self.get_end_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.end"))) + .unwrap_or_default(); + + ctx.builder.build_load(pend, var_name.as_str()).into_int_value() + } + + /// Stores the `step` value into this instance. + pub fn store_step( + &self, + ctx: &CodeGenContext<'ctx, '_>, + step: IntValue<'ctx>, + ) { + debug_assert_eq!(step.get_type().get_bit_width(), 32); + + let pstep = self.get_start_ptr(ctx); + ctx.builder.build_store(pstep, step); + } + + /// Returns the `step` value of this `range`. + pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstep = self.get_step_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.step"))) + .unwrap_or_default(); + + ctx.builder.build_load(pstep, var_name.as_str()).into_int_value() + } +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 734e945..58f090f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, RangeValue}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_type, @@ -870,18 +870,11 @@ pub fn gen_call<'ctx, G: CodeGenerator>( /// respectively. pub fn destructure_range<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - range: PointerValue<'ctx>, + range: RangeValue<'ctx>, ) -> (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>) { - let int32 = ctx.ctx.i32_type(); - let start = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(0, false)], Some("range.start")) - .into_int_value(); - let end = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(1, false)], Some("range.stop")) - .into_int_value(); - let step = ctx - .build_gep_and_load(range, &[int32.const_zero(), int32.const_int(2, false)], Some("range.step")) - .into_int_value(); + let start = range.load_start(ctx, None); + let end = range.load_end(ctx, None); + let step = range.load_step(ctx, None); (start, end, step) } @@ -965,7 +958,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let list_content; if is_range { - let iter_val = iter_val.into_pointer_value(); + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff"); // add 1 to the length as the value is rounded to zero diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 8ac5a7d..b9fcda0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -6,7 +6,7 @@ use super::{ }; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, RangeValue}, expr::gen_binop_expr, gen_in_range_check, }, @@ -321,7 +321,7 @@ pub fn gen_for( return Ok(()) }; if is_iterable_range_expr { - let iter_val = iter_val.into_pointer_value(); + let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index d2eb458..d4bf68c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,7 @@ use super::*; use crate::{ codegen::{ + classes::RangeValue, expr::destructure_range, irrt::{ calculate_len_for_slice_range, @@ -1453,7 +1454,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = arg.into_pointer_value(); + let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) } else { -- 2.44.1 From 847091580956de165ad09ec56d8b4a41af25850e Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 22 Jan 2024 16:51:35 +0800 Subject: [PATCH 3/3] core: Add NDArrayValue and helper functions --- nac3core/src/codegen/classes.rs | 489 ++++++++++++++++++++++++++++++- nac3core/src/codegen/irrt/irrt.c | 38 ++- nac3core/src/codegen/irrt/mod.rs | 100 +++++-- nac3core/src/codegen/mod.rs | 27 -- nac3core/src/toplevel/numpy.rs | 219 +++----------- 5 files changed, 630 insertions(+), 243 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 4f8e785..86822ee 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -3,7 +3,12 @@ use inkwell::{ types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, values::{BasicValueEnum, IntValue, PointerValue}, }; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::{ + CodeGenContext, + CodeGenerator, + irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, + stmt::gen_for_callback, +}; #[cfg(not(debug_assertions))] pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} @@ -380,3 +385,485 @@ impl<'ctx> RangeValue<'ctx> { ctx.builder.build_load(pstep, var_name.as_str()).into_int_value() } } + +#[cfg(not(debug_assertions))] +pub fn assert_is_ndarray<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} + +#[cfg(debug_assertions)] +pub fn assert_is_ndarray<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) { + if let Err(msg) = NDArrayValue::is_instance(value, llvm_usize) { + panic!("{msg}") + } +} + +/// Proxy type for accessing an `NDArray` value in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> NDArrayValue<'ctx> { + /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")) + }; + if llvm_ndarray_ty.count_fields() != 3 { + return Err(format!("Expected 3 fields in `NDArray`, got {}", llvm_ndarray_ty.count_fields())) + } + + let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); + let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { + return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")) + }; + if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected {}-bit int type for `ndarray.0`, got {}-bit int", + llvm_usize.get_bit_width(), + ndarray_ndims_ty.get_bit_width())) + } + + let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); + let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { + return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")) + }; + let ndarray_dims = ndarray_pdims.get_element_type(); + let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { + return Err(format!("Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}")) + }; + if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + llvm_usize.get_bit_width(), + ndarray_dims.get_bit_width())) + } + + let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); + let Ok(_) = PointerType::try_from(ndarray_data_ty) else { + return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")) + }; + + Ok(()) + } + + /// Creates an [NDArrayValue] from a [PointerValue]. + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + assert_is_ndarray(ptr, llvm_usize); + NDArrayValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `NDArray` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. + fn get_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + } + } + + /// Stores the number of dimensions `ndims` into this instance. + pub fn store_ndims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.get_ndims(ctx); + ctx.builder.build_store(pndims, ndims); + } + + /// Returns the number of dimensions of this `NDArray` as a value. + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.get_ndims(ctx); + ctx.builder.build_load(pndims, "").into_int_value() + } + + /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` + /// on the field. + fn get_dims_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of dimension sizes `dims` into this instance. + fn store_dims(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_dims_ptr(ctx), dims); + } + + /// Convenience method for creating a new array storing dimension sizes with the given `size`. + pub fn create_dims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + llvm_usize: IntType<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_dims(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "")); + } + + /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. + pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> { + NDArrayDimsProxy(self.clone()) + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + fn get_data_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_data_ptr(ctx), data); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and + /// `size`. + pub fn create_data( + &self, + ctx: &CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "")); + } + + /// Returns a proxy object to the field storing the data of this `NDArray`. + pub fn get_data(&self) -> NDArrayDataProxy<'ctx> { + NDArrayDataProxy(self.clone()) + } +} + +impl<'ctx> Into> for NDArrayValue<'ctx> { + fn into(self) -> PointerValue<'ctx> { + self.get_ptr() + } +} + +/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDimsProxy<'ctx>(NDArrayValue<'ctx>); + +impl<'ctx> NDArrayDimsProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.dims")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_dims_ptr(ctx), var_name.as_str()).into_pointer_value() + } + + /// Returns the pointer to the size of the `idx`-th dimension. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + self.0.load_ndims(ctx), + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + let var_name = name + .map(|v| format!("{v}.addr")) + .unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + var_name.as_str(), + ) + } + } + + /// Returns the size of the `idx`-th dimension. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> IntValue<'ctx> { + let ptr = self.ptr_offset(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()).into_int_value() + } +} + +/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDataProxy<'ctx>(NDArrayValue<'ctx>); + +impl<'ctx> NDArrayDataProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_data_ptr(ctx), var_name.as_str()).into_pointer_value() + } + + pub unsafe fn ptr_to_data_flattened_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + name.unwrap_or_default(), + ) + } + + /// Returns the pointer to the data at the `idx`-th flattened index. + pub fn ptr_to_data_flattened( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let ndims = self.0.load_ndims(ctx); + let dims = self.0.get_dims().get_ptr(ctx); + let data_sz = call_ndarray_calc_size(generator, ctx, ndims, dims); + + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + data_sz, + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds with size {1}", + [Some(idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { + self.ptr_to_data_flattened_unchecked(ctx, idx, name) + } + } + + pub unsafe fn get_flattened_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_to_data_flattened_unchecked(ctx, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the `idx`-th flattened index. + pub fn get_flattened( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_to_data_flattened(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + pub unsafe fn ptr_offset_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let indices_elem_ty = indices.get_data().get_ptr(ctx).get_type().get_element_type(); + let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { + panic!("Expected list[int32] but got {indices_elem_ty}") + }; + assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}"); + + let index = call_ndarray_flatten_index( + generator, + ctx, + self.0, + indices, + ).unwrap(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[index], + name.unwrap_or_default(), + ) + } + } + + /// Returns the pointer to the data at the index specified by `indices`. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let nidx_leq_ndims = ctx.builder.build_int_compare( + IntPredicate::SLE, + indices.load_size(ctx, None), + self.0.load_ndims(ctx), + "" + ); + ctx.make_assert( + generator, + nidx_leq_ndims, + "0:IndexError", + "invalid index to scalar variable", + [None, None, None], + ctx.current_loc, + ); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let indices_len = indices.load_size(ctx, None); + let ndarray_len = self.0.load_ndims(ctx); + + let min_fn_name = format!("llvm.umin.i{}", llvm_usize.get_bit_width()); + let min_fn = ctx.module.get_function(min_fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_usize.into(), llvm_usize.into()], + false + ); + ctx.module.add_function(min_fn_name.as_str(), fn_type, None) + }); + + let len = ctx + .builder + .build_call(min_fn, &[indices_len.into(), ndarray_len.into()], "") + .try_as_basic_value() + .map_left(|v| v.into_int_value()) + .left() + .unwrap(); + + let i = ctx.builder.build_load(i_addr, "").into_int_value(); + Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, "")) + }, + |generator, ctx, i_addr| { + let i = ctx.builder.build_load(i_addr, "").into_int_value(); + let (dim_idx, dim_sz) = unsafe { + ( + indices.get_data().get_unchecked(ctx, i, None).into_int_value(), + self.0.get_dims().get(ctx, generator, i, None), + ) + }; + + let dim_lt = ctx.builder.build_int_compare( + IntPredicate::SLT, + dim_idx, + dim_sz, + "" + ); + + ctx.make_assert( + generator, + dim_lt, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(dim_idx), Some(dim_sz), None], + ctx.current_loc, + ); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + ).unwrap(); + + unsafe { + self.ptr_offset_unchecked(ctx, generator, indices, name) + } + } + + pub unsafe fn get_unsafe( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_unchecked(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the index specified by `indices`. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + indices: ListValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset(ctx, generator, indices, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } +} diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 8b28bc1..97969d1 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -268,4 +268,40 @@ void __nac3_ndarray_calc_nd_indices64( idxs[i] = (index / stride) % dims[i]; stride *= dims[i]; } -} \ No newline at end of file +} + +uint32_t __nac3_ndarray_flatten_index( + const uint32_t* dims, + uint32_t num_dims, + const uint32_t* indices, + uint32_t num_indices +) { + uint32_t idx = 0; + uint32_t stride = 1; + for (uint32_t i = num_dims - 1; i-- >= 0; ) { + if (i < num_indices) { + idx += (stride * indices[i]); + } + + stride *= dims[i]; + } + return idx; +} + +uint64_t __nac3_ndarray_flatten_index64( + const uint64_t* dims, + uint64_t num_dims, + const uint32_t* indices, + uint64_t num_indices +) { + uint64_t idx = 0; + uint64_t stride = 1; + for (uint64_t i = num_dims - 1; i-- >= 0; ) { + if (i < num_indices) { + idx += (stride * indices[i]); + } + + stride *= dims[i]; + } + return idx; +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 3b21dd9..5702ba9 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,8 +1,7 @@ use crate::typecheck::typedef::Type; use super::{ - classes::ListValue, - assert_is_ndarray, + classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, }; @@ -607,7 +606,7 @@ pub fn call_ndarray_calc_size<'ctx>( pub fn call_ndarray_init_dims<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, shape: ListValue<'ctx>, ) { let llvm_void = ctx.ctx.void_type(); @@ -617,8 +616,6 @@ pub fn call_ndarray_init_dims<'ctx>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_is_ndarray(ndarray); - let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_init_dims", 64 => "__nac3_ndarray_init_dims64", @@ -637,22 +634,14 @@ pub fn call_ndarray_init_dims<'ctx>( ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None) }); - let ndarray_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ); + let ndarray_dims = ndarray.get_dims(); let shape_data = shape.get_data(); - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); + let ndarray_num_dims = ndarray.load_ndims(ctx); ctx.builder.build_call( ndarray_init_dims_fn, &[ - ndarray_dims.into(), + ndarray_dims.get_ptr(ctx).into(), shape_data.get_ptr(ctx).into(), ndarray_num_dims.into(), ], @@ -669,12 +658,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, ) -> Result, String> { - assert_is_ndarray(ndarray); - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); @@ -698,16 +684,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) }); - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - let ndarray_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_pointer_value(); + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.get_dims(); let indices = ctx.builder.build_array_alloca( llvm_usize, @@ -719,7 +697,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ndarray_calc_nd_indices_fn, &[ index.into(), - ndarray_dims.into(), + ndarray_dims.get_ptr(ctx).into(), ndarray_num_dims.into(), indices.into(), ], @@ -727,4 +705,64 @@ pub fn call_ndarray_calc_nd_indices<'ctx>( ); Ok(indices) +} + +/// Generates a call to `__nac3_ndarray_flatten_index`. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx>( + generator: &dyn CodeGenerator, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: ListValue<'ctx>, +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_flatten_index", + 64 => "__nac3_ndarray_flatten_index64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_usize.into(), + llvm_pusize.into(), + llvm_pi32.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.get_dims(); + let indices_size = indices.load_size(ctx, None); + let indices_data = indices.get_data(); + + let index = ctx.builder + .build_call( + ndarray_flatten_index_fn, + &[ + ndarray_num_dims.into(), + ndarray_dims.get_ptr(ctx).into(), + indices_size.into(), + indices_data.get_ptr(ctx).into(), + ], + "", + ) + .try_as_basic_value() + .map_left(|v| v.into_int_value()) + .left() + .unwrap(); + + Ok(index) } \ No newline at end of file diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2102076..66c1d4f 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -34,9 +34,6 @@ use std::sync::{ }; use std::thread; -#[cfg(debug_assertions)] -use inkwell::types::AnyTypeEnum; - pub mod classes; pub mod concrete_type; pub mod expr; @@ -999,27 +996,3 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } - -/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. -fn assert_is_ndarray(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_ndarray_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") - }; - - assert_eq!(llvm_ndarray_ty.count_fields(), 3); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); - let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { - unreachable!() - }; - let BasicTypeEnum::PointerType(dims) = ndarray_dims else { - panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") - }; - assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); - } - - value -} diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 4b6f496..820b575 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -3,7 +3,7 @@ use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ - classes::ListValue, + classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, irrt::{ @@ -27,11 +27,10 @@ fn create_ndarray_const_shape<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ArrayValue<'ctx> -) -> Result, String> { +) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); @@ -68,54 +67,18 @@ fn create_ndarray_const_shape<'ctx, 'a>( llvm_ndarray_t.into(), None, )?; + let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false); + ndarray.store_ndims(ctx, generator, num_dims); - let ndarray_num_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "", - ) - }; - ctx.builder.build_store(ndarray_num_dims, num_dims); - - let ndarray_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "", - ) - }; - - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - - ctx.builder.build_store( - ndarray_dims, - ctx.builder.build_array_alloca( - llvm_usize, - ndarray_num_dims, - "", - ), - ); + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); for i in 0..shape.get_type().len() { - let ndarray_dim = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_pointer_value(); - let ndarray_dim = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray_dim, - &[llvm_i32.const_int(i as u64, true)], - "", - ) - }; + let ndarray_dim = ndarray + .get_dims() + .ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None); let shape_dim = ctx.builder.build_extract_value(shape, i, "") .map(|val| val.into_int_value()) .unwrap(); @@ -123,42 +86,14 @@ fn create_ndarray_const_shape<'ctx, 'a>( ctx.builder.build_store(ndarray_dim, shape_dim); } - let (ndarray_num_dims, ndarray_dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; + let ndarray_dims = ndarray.get_dims().get_ptr(ctx); let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), - ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), - ); - - let ndarray_data = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - "", - ) - }; - ctx.builder.build_store( - ndarray_data, - ctx.builder.build_array_alloca( - llvm_ndarray_data_t, - ndarray_num_elems, - "" - ), + ndarray.load_ndims(ctx), + ndarray_dims, ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); Ok(ndarray) } @@ -214,7 +149,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); @@ -284,79 +219,23 @@ fn call_ndarray_empty_impl<'ctx, 'a>( llvm_ndarray_t.into(), None, )?; + let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); let num_dims = shape.load_size(ctx, None); + ndarray.store_ndims(ctx, generator, num_dims); - let ndarray_num_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "", - ) - }; - ctx.builder.build_store(ndarray_num_dims, num_dims); - - let ndarray_dims = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "", - ) - }; - - let ndarray_num_dims = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_int_value(); - - ctx.builder.build_store( - ndarray_dims, - ctx.builder.build_array_alloca( - llvm_usize, - ndarray_num_dims, - "", - ), - ); + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); call_ndarray_init_dims(generator, ctx, ndarray, shape); - let (ndarray_num_dims, ndarray_dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), - ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), - ); - - let ndarray_data = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - "", - ) - }; - ctx.builder.build_store( - ndarray_data, - ctx.builder.build_array_alloca( - llvm_ndarray_data_t, - ndarray_num_elems, - "", - ), + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); Ok(ndarray) } @@ -369,35 +248,19 @@ fn call_ndarray_empty_impl<'ctx, 'a>( fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> where ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result, String>, { - let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (num_dims, dims) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; - let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - ctx.builder.build_load(num_dims, "").into_int_value(), - ctx.builder.build_load(dims, "").into_pointer_value(), + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), ); gen_for_callback( @@ -417,21 +280,11 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "")) }, |generator, ctx, i_addr| { - let ndarray_data = ctx.build_gep_and_load( - ndarray, - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - None - ).into_pointer_value(); - let i = ctx.builder .build_load(i_addr, "") .into_int_value(); let elem = unsafe { - ctx.builder.build_in_bounds_gep( - ndarray_data, - &[i], - "" - ) + ndarray.get_data().ptr_to_data_flattened_unchecked(ctx, i, None) }; let value = value_fn(generator, ctx, i)?; @@ -459,7 +312,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( fn ndarray_fill_indexed<'ctx, 'a, ValueFn>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: PointerValue<'ctx>, + ndarray: NDArrayValue<'ctx>, value_fn: ValueFn, ) -> Result<(), String> where @@ -491,7 +344,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let supported_types = [ ctx.primitives.int32, ctx.primitives.int64, @@ -527,7 +380,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, shape: ListValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let supported_types = [ ctx.primitives.int32, ctx.primitives.int64, @@ -564,7 +417,7 @@ fn call_ndarray_full_impl<'ctx, 'a>( elem_ty: Type, shape: ListValue<'ctx>, fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { +) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; ndarray_fill_flattened( generator, @@ -633,7 +486,7 @@ fn call_ndarray_eye_impl<'ctx, 'a>( nrows: IntValue<'ctx>, ncols: IntValue<'ctx>, offset: IntValue<'ctx>, -) -> Result, String> { +) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize_2 = llvm_usize.array_type(2); @@ -718,7 +571,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.zeros`. @@ -742,7 +595,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.ones`. @@ -766,7 +619,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( context, context.primitives.float, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.full`. @@ -794,7 +647,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( fill_value_ty, ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), fill_value_arg, - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.eye`. @@ -839,7 +692,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>( nrows_arg.into_int_value(), ncols_arg.into_int_value(), offset_arg.into_int_value(), - ) + ).map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.identity`. @@ -866,5 +719,5 @@ pub fn gen_ndarray_identity<'ctx, 'a>( n_arg.into_int_value(), n_arg.into_int_value(), llvm_usize.const_zero(), - ) + ).map(NDArrayValue::into) } \ No newline at end of file -- 2.44.1