From 5ee08b585f07e267629bf9020802ed2d6f85ca1f Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 23 Jan 2024 17:21:24 +0800 Subject: [PATCH] core: Add ListValue and helper functions --- nac3core/src/codegen/classes.rs | 225 +++++++++++++++++++++++++++++++ nac3core/src/codegen/expr.rs | 89 +++++------- nac3core/src/codegen/irrt/mod.rs | 45 +++---- nac3core/src/codegen/mod.rs | 17 +-- nac3core/src/codegen/stmt.rs | 27 ++-- nac3core/src/toplevel/numpy.rs | 47 +++---- 6 files changed, 308 insertions(+), 142 deletions(-) create mode 100644 nac3core/src/codegen/classes.rs diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs new file mode 100644 index 00000000..66f74ac4 --- /dev/null +++ b/nac3core/src/codegen/classes.rs @@ -0,0 +1,225 @@ +use inkwell::{ + IntPredicate, + types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, + values::{BasicValueEnum, IntValue, PointerValue}, +}; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +#[cfg(not(debug_assertions))] +pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} + +#[cfg(debug_assertions)] +pub fn assert_is_list<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) { + if let Err(msg) = ListValue::is_instance(value, llvm_usize) { + panic!("{msg}") + } +} + +/// Proxy type for accessing a `list` value in LLVM. +#[derive(Copy, Clone)] +pub struct ListValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> ListValue<'ctx> { + /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let llvm_list_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { + panic!("Expected struct type for `list` type, got {llvm_list_ty}") + }; + if llvm_list_ty.count_fields() != 2 { + return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields())) + } + + let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); + let Ok(_) = PointerType::try_from(list_size_ty) else { + return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")) + }; + + let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); + let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { + return Err(format!("Expected int type for `list.1`, got {list_data_ty}")) + }; + if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!("Expected {}-bit int type for `list.1`, got {}-bit int", + llvm_usize.get_bit_width(), + list_data_ty.get_bit_width())) + } + + Ok(()) + } + + /// Creates an [ListValue] from a [PointerValue]. + pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { + assert_is_list(ptr, llvm_usize); + ListValue(ptr, name) + } + + /// Returns the underlying [PointerValue] pointing to the `list` instance. + pub fn get_ptr(&self) -> PointerValue<'ctx> { + self.0 + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + fn get_data_pptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.get_ptr(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + } + } + + /// Returns the pointer to the field storing the size of this `list`. + fn get_size_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.1.map(|v| format!("{v}.size.addr")).unwrap_or_default(); + + unsafe { + ctx.builder.build_in_bounds_gep( + self.0, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.get_data_pptr(ctx), data); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and `size`. + /// + /// If `size` is [None], the size stored in the field of this instance is used instead. + pub fn create_data( + &self, + ctx: &CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: Option>, + ) { + let size = size.unwrap_or_else(|| self.load_size(ctx, None)); + self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "")); + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + pub fn get_data(&self) -> ListDataProxy<'ctx> { + ListDataProxy(self.clone()) + } + + /// Stores the `size` of this `list` into this instance. + pub fn store_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &dyn CodeGenerator, + size: IntValue<'ctx>, + ) { + debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + + let psize = self.get_size_ptr(ctx); + ctx.builder.build_store(psize, size); + } + + /// Returns the size of this `list` as a value. + pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let psize = self.get_size_ptr(ctx); + let var_name = name + .map(|v| v.to_string()) + .or_else(|| self.1.map(|v| format!("{v}.size"))) + .unwrap_or_default(); + + ctx.builder.build_load(psize, var_name.as_str()).into_int_value() + } +} + +/// Proxy type for accessing the `data` array of an `list` instance in LLVM. +#[derive(Copy, Clone)] +pub struct ListDataProxy<'ctx>(ListValue<'ctx>); + +impl<'ctx> ListDataProxy<'ctx> { + /// Returns the single-indirection pointer to the array. + pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder.build_load(self.0.get_data_pptr(ctx), var_name.as_str()).into_pointer_value() + } + + pub unsafe fn ptr_offset_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name + .map(|v| format!("{v}.addr")) + .unwrap_or_default(); + + ctx.builder.build_in_bounds_gep( + self.get_ptr(ctx), + &[idx], + var_name.as_str(), + ) + } + + /// Returns the pointer to the data at the `idx`-th index. + pub fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + + let in_range = ctx.builder.build_int_compare( + IntPredicate::ULT, + idx, + self.0.load_size(ctx, None), + "" + ); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "list index out of range", + [None, None, None], + ctx.current_loc, + ); + + unsafe { + self.ptr_offset_unchecked(ctx, idx, name) + } + } + + pub unsafe fn get_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset_unchecked(ctx, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } + + /// Returns the data at the `idx`-th flattened index. + pub fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + idx: IntValue<'ctx>, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()) + } +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 79ebe4ad..734e9451 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ + classes::ListValue, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_type, @@ -896,43 +897,26 @@ pub fn allocate_list<'ctx, G: CodeGenerator>( ty: BasicTypeEnum<'ctx>, length: IntValue<'ctx>, name: Option<&str>, -) -> PointerValue<'ctx> { +) -> ListValue<'ctx> { let size_t = generator.get_size_type(ctx.ctx); - let i32_t = ctx.ctx.i32_type(); // List structure; type { ty*, size_t } let arr_ty = ctx.ctx .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); - let zero = ctx.ctx.i32_type().const_zero(); let arr_str_ptr = ctx.builder.build_alloca( arr_ty, format!("{}.addr", name.unwrap_or("list")).as_str() ); + let list = ListValue::from_ptr_val(arr_str_ptr, size_t, Some("list")); - unsafe { - // Pointer to the `length` element of the list structure - let len_ptr = ctx.builder.build_in_bounds_gep( - arr_str_ptr, - &[zero, i32_t.const_int(1, false)], - "" - ); - let length = ctx.builder.build_int_z_extend( - length, - size_t, - "" - ); - ctx.builder.build_store(len_ptr, length); + let length = ctx.builder.build_int_z_extend( + length, + size_t, + "" + ); + list.store_size(ctx, generator, length); + list.create_data(ctx, ty, None); - // Pointer to the `data` element of the list structure - let arr_ptr = ctx.builder.build_array_alloca(ty, length, ""); - let ptr_to_arr = ctx.builder.build_in_bounds_gep( - arr_str_ptr, - &[zero, i32_t.const_zero()], - "" - ); - ctx.builder.build_store(ptr_to_arr, arr_ptr); - } - - arr_str_ptr + list } /// Generates LLVM IR for a [list comprehension expression][expr]. @@ -1006,8 +990,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( list_alloc_size.into_int_value(), Some("listcomp.addr") ); - list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr")) - .into_pointer_value(); + list_content = list.get_data().get_ptr(ctx); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); @@ -1042,8 +1025,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ) .into_int_value(); list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp")); - list_content = - ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value(); + list_content = list.get_data().get_ptr(ctx); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); @@ -1065,12 +1047,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } // Emits the content of `cont_bb` - let emit_cont_bb = |ctx: &CodeGenContext| { + let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { ctx.builder.position_at_end(cont_bb); - let len_ptr = unsafe { - ctx.builder.build_gep(list, &[zero_size_t, int32.const_int(1, false)], "length") - }; - ctx.builder.build_store(len_ptr, ctx.builder.build_load(index, "index")); + list.store_size(ctx, generator, ctx.builder.build_load(index, "index").into_int_value()); }; for cond in ifs { @@ -1079,7 +1058,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } else { // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // no element matches the predicate - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); return Ok(None) }; @@ -1092,7 +1071,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let Some(elem) = generator.gen_expr(ctx, elt)? else { // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); return Ok(None) }; @@ -1104,9 +1083,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( .build_store(index, ctx.builder.build_int_add(i, size_t.const_int(1, false), "inc")); ctx.builder.build_unconditional_branch(test_bb); - emit_cont_bb(ctx); + emit_cont_bb(ctx, generator, list); - Ok(Some(list.into())) + Ok(Some(list.get_ptr().into())) } /// Generates LLVM IR for a [binary operator expression][expr]. @@ -1226,6 +1205,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ) -> Result>, String> { ctx.current_loc = expr.location; let int32 = ctx.ctx.i32_type(); + let usize = generator.get_size_type(ctx.ctx); let zero = int32.const_int(0, false); let loc = ctx.debug_info.0.create_debug_location( @@ -1296,19 +1276,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); - let arr_ptr = ctx.build_gep_and_load(arr_str_ptr, &[zero, zero], Some("list.ptr.addr")) - .into_pointer_value(); - unsafe { - for (i, v) in elements.iter().enumerate() { - let elem_ptr = ctx.builder.build_gep( - arr_ptr, - &[int32.const_int(i as u64, false)], - "elem_ptr", - ); - ctx.builder.build_store(elem_ptr, *v); - } + let arr_ptr = arr_str_ptr.get_data(); + for (i, v) in elements.iter().enumerate() { + let elem_ptr = arr_ptr + .ptr_offset(ctx, generator, usize.const_int(i as u64, false), Some("elem_ptr")); + ctx.builder.build_store(elem_ptr, *v); } - arr_str_ptr.into() + arr_str_ptr.get_ptr().into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -1758,9 +1732,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None) }; + let v = ListValue::from_ptr_val(v, usize, Some("arr")); let ty = ctx.get_llvm_type(generator, *ty); - let arr_ptr = ctx.build_gep_and_load(v, &[zero, zero], Some("arr.addr")) - .into_pointer_value(); if let ExprKind::Slice { lower, upper, step } = &slice.node { let one = int32.const_int(1, false); let Some((start, end, step)) = @@ -1800,11 +1773,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.into() + res_array_ret.get_ptr().into() } else { - let len = ctx - .build_gep_and_load(v, &[zero, int32.const_int(1, false)], Some("len")) - .into_int_value(); + let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() } else { @@ -1843,7 +1814,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( [Some(raw_index), Some(len), None], expr.location, ); - ctx.build_gep_and_load(arr_ptr, &[index], None).into() + v.get_data().get(ctx, generator, index, None).into() } } TypeEnum::TNDArray { .. } => { diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 530ce0ae..3b21dd9c 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,11 @@ use crate::typecheck::typedef::Type; -use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator}; +use super::{ + classes::ListValue, + assert_is_ndarray, + CodeGenContext, + CodeGenerator, +}; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -158,12 +163,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( step: &Option>>>, ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G, - list: PointerValue<'ctx>, + list: ListValue<'ctx>, ) -> Result, IntValue<'ctx>, IntValue<'ctx>)>, String> { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let one = int32.const_int(1, false); - let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value(); + let length = list.load_size(ctx, Some("length")); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32"); Ok(Some(match (start, end, step) { (s, e, None) => ( @@ -295,9 +300,9 @@ pub fn list_slice_assignment<'ctx>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, ty: BasicTypeEnum<'ctx>, - dest_arr: PointerValue<'ctx>, + dest_arr: ListValue<'ctx>, dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), - src_arr: PointerValue<'ctx>, + src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { let size_ty = generator.get_size_type(ctx.ctx); @@ -326,21 +331,21 @@ pub fn list_slice_assignment<'ctx>( let zero = int32.const_zero(); let one = int32.const_int(1, false); - let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr")); + let dest_arr_ptr = dest_arr.get_data().get_ptr(ctx); let dest_arr_ptr = ctx.builder.build_pointer_cast( - dest_arr_ptr.into_pointer_value(), + dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast", ); - let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value(); + let dest_len = dest_arr.load_size(ctx, Some("dest.len")); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32"); - let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr")); + let src_arr_ptr = src_arr.get_data().get_ptr(ctx); let src_arr_ptr = ctx.builder.build_pointer_cast( - src_arr_ptr.into_pointer_value(), + src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast", ); - let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value(); + let src_len = src_arr.load_size(ctx, Some("src.len")); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32"); // index in bound and positive should be done @@ -443,9 +448,8 @@ pub fn list_slice_assignment<'ctx>( let cont_bb = ctx.ctx.append_basic_block(current, "cont"); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb); ctx.builder.position_at_end(update_bb); - let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") }; let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len"); - ctx.builder.build_store(dest_len_ptr, new_len); + dest_arr.store_size(ctx, generator, new_len); ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.position_at_end(cont_bb); } @@ -604,11 +608,8 @@ pub fn call_ndarray_init_dims<'ctx>( generator: &dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, ndarray: PointerValue<'ctx>, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) { - assert_is_ndarray(ndarray); - assert_is_list(shape); - let llvm_void = ctx.ctx.void_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -616,6 +617,8 @@ pub fn call_ndarray_init_dims<'ctx>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + assert_is_ndarray(ndarray); + let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_init_dims", 64 => "__nac3_ndarray_init_dims64", @@ -639,11 +642,7 @@ pub fn call_ndarray_init_dims<'ctx>( &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], None, ); - let shape_data = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None - ); + let shape_data = shape.get_data(); let ndarray_num_dims = ctx.build_gep_and_load( ndarray, &[llvm_i32.const_zero(), llvm_i32.const_zero()], @@ -654,7 +653,7 @@ pub fn call_ndarray_init_dims<'ctx>( ndarray_init_dims_fn, &[ ndarray_dims.into(), - shape_data.into(), + shape_data.get_ptr(ctx).into(), ndarray_num_dims.into(), ], "", diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index b1836a07..2102076d 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -37,6 +37,7 @@ use std::thread; #[cfg(debug_assertions)] use inkwell::types::AnyTypeEnum; +pub mod classes; pub mod concrete_type; pub mod expr; mod generator; @@ -999,22 +1000,6 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } -/// Checks whether the pointer `value` refers to a `list` in LLVM. -fn assert_is_list(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_shape_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { - panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") - }; - assert_eq!(llvm_shape_ty.count_fields(), 2); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); - } - - value -} - /// Checks whether the pointer `value` refers to an `NDArray` in LLVM. fn assert_is_ndarray(value: PointerValue) -> PointerValue { #[cfg(debug_assertions)] diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1cf57b29..8ac5a7db 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::{ codegen::{ + classes::ListValue, expr::gen_binop_expr, gen_in_range_check, }, @@ -92,6 +93,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( pattern: &Expr>, name: Option<&str>, ) -> Result>, String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples Ok(Some(match &pattern.node { @@ -132,16 +135,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( ExprKind::Subscript { value, slice, .. } => { match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { TypeEnum::TList { .. } => { - let i32_type = ctx.ctx.i32_type(); - let zero = i32_type.const_zero(); let v = generator .gen_expr(ctx, value)? .unwrap() .to_basic_value_enum(ctx, generator, value.custom.unwrap())? .into_pointer_value(); - let len = ctx - .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) - .into_int_value(); + let v = ListValue::from_ptr_val(v, llvm_usize, None); + let len = v.load_size(ctx, Some("len")); let raw_index = generator .gen_expr(ctx, slice)? .unwrap() @@ -180,12 +180,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( [Some(raw_index), Some(len), None], slice.location, ); - unsafe { - let arr_ptr = ctx - .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) - .into_pointer_value(); - ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) - } + v.get_data().ptr_offset(ctx, generator, index, name) } TypeEnum::TNDArray { .. } => { @@ -206,6 +201,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( target: &Expr>, value: ValueEnum<'ctx>, ) -> Result<(), String> { + let llvm_usize = generator.get_size_type(ctx.ctx); + match &target.node { ExprKind::Tuple { elts, .. } => { let BasicValueEnum::StructValue(v) = @@ -233,6 +230,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? .into_pointer_value(); + let ls = ListValue::from_ptr_val(ls, llvm_usize, None); let Some((start, end, step)) = handle_slice_indices(lower, upper, step, ctx, generator, ls)? else { return Ok(()) @@ -240,9 +238,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = value .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); - let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { - unreachable!() - }; + let value = ListValue::from_ptr_val(value, llvm_usize, None); + let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { + unreachable!() + }; let ty = ctx.get_llvm_type(generator, *ty); let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 13bb8a53..4b6f4968 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -3,6 +3,7 @@ use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ + classes::ListValue, CodeGenContext, CodeGenerator, irrt::{ @@ -212,7 +213,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); @@ -239,29 +240,15 @@ fn call_ndarray_empty_impl<'ctx, 'a>( let i = ctx.builder .build_load(i_addr, "") .into_int_value(); - let shape_len = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None, - ).into_int_value(); + let shape_len = shape.load_size(ctx, None); - Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, "")) + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "")) }, |generator, ctx, i_addr| { - let shape_elems = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None - ).into_pointer_value(); - let i = ctx.builder .build_load(i_addr, "") .into_int_value(); - let shape_dim = ctx.build_gep_and_load( - shape_elems, - &[i], - None - ).into_int_value(); + let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value(); let shape_dim_gez = ctx.builder.build_int_compare( IntPredicate::SGE, @@ -298,11 +285,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>( None, )?; - let num_dims = ctx.build_gep_and_load( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - None - ).into_int_value(); + let num_dims = shape.load_size(ctx, None); let ndarray_num_dims = unsafe { ctx.builder.build_in_bounds_gep( @@ -507,7 +490,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -543,7 +526,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -579,7 +562,7 @@ fn call_ndarray_full_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - shape: PointerValue<'ctx>, + shape: ListValue<'ctx>, fill_value: BasicValueEnum<'ctx>, ) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; @@ -725,6 +708,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -733,7 +717,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -748,6 +732,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -756,7 +741,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -771,6 +756,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 1); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -779,7 +765,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>( generator, context, context.primitives.float, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), ) } @@ -794,6 +780,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( assert!(obj.is_none()); assert_eq!(args.len(), 2); + let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; @@ -805,7 +792,7 @@ pub fn gen_ndarray_full<'ctx, 'a>( generator, context, fill_value_ty, - shape_arg.into_pointer_value(), + ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), fill_value_arg, ) }