diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4b781d2..8d7f8e3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1090,33 +1090,6 @@ pub fn destructure_range<'ctx>( (start, end, step) } -/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting -/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. -/// -/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element -/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to -/// generate a sized list with an unknown element type. -pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Option>, - length: IntValue<'ctx>, - name: Option<&'ctx str>, -) -> ListValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_elem_ty = ty.unwrap_or(llvm_usize.into()); - - // List structure; type { ty*, size_t } - let arr_ty = ListType::new(generator, ctx.ctx, llvm_elem_ty); - let list = arr_ty.alloca_var(generator, ctx, name); - - let length = ctx.builder.build_int_z_extend(length, llvm_usize, "").unwrap(); - list.store_size(ctx, generator, length); - list.create_data(ctx, llvm_elem_ty, None); - - list -} - /// Generates LLVM IR for a [list comprehension expression][expr]. pub fn gen_comprehension<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1189,12 +1162,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = allocate_list( + list = ListType::new(generator, ctx.ctx, elem_ty).construct( generator, ctx, - Some(elem_ty), list_alloc_size.into_int_value(), - Some("listcomp.addr"), + Some("listcomp"), ); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); @@ -1241,7 +1213,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); + list = ListType::new(generator, ctx.ctx, elem_ty).construct( + generator, + ctx, + length, + Some("listcomp"), + ); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1406,7 +1383,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = allocate_list(generator, ctx, Some(llvm_elem_ty), size, None); + let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) + .construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1493,10 +1471,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = allocate_list( + let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( generator, ctx, - Some(elem_llvm_ty), ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), None, ); @@ -2553,7 +2530,20 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Some(elements[0].get_type()) }; let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); - let arr_str_ptr = allocate_list(generator, ctx, ty, length, Some("list")); + let arr_str_ptr = if let Some(ty) = ty { + ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("list"), + ) + } else { + ListType::new_untyped(generator, ctx.ctx).construct_empty( + generator, + ctx, + Some("list"), + ) + }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { let elem_ptr = arr_ptr.ptr_offset( @@ -3031,8 +3021,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = - allocate_list(generator, ctx, Some(ty), length, Some("ret")); + let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( + generator, + ctx, + length, + Some("ret"), + ); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9328bb8..9b5af0f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,5 +1,5 @@ use inkwell::{ - types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, + types::{BasicType, BasicTypeEnum, PointerType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -639,17 +639,17 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); + let list_elem_ty = list_ty.element_type().unwrap(); let ndims = llvm_usize.const_int(1, false); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Getting ndims for list[ndarray] not supported") @@ -670,10 +670,10 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let list_elem_ty = src_lst.get_type().element_type(); + let list_elem_ty = src_lst.get_type().element_type().unwrap(); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => { // The stride of elements in this dimension, i.e. the number of elements between arr[i] @@ -733,7 +733,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( )?; } - AnyTypeEnum::PointerType(ptr_ty) + BasicTypeEnum::PointerType(ptr_ty) if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => { todo!("Not implemented for list[ndarray]") diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 6608a80..337d049 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,68 +1,113 @@ use inkwell::{ - context::Context, + context::{AsContextRef, Context}, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - AddressSpace, + values::{IntValue, PointerValue}, + AddressSpace, IntPredicate, OptimizationLevel, }; +use itertools::Itertools; + +use nac3core_derive::StructFields; use super::ProxyType; -use crate::codegen::{ - values::{ListValue, ProxyValue}, - CodeGenContext, CodeGenerator, +use crate::{ + codegen::{ + types::structure::{ + check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + }, + values::{ListValue, ProxyValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, }; /// Proxy type for a `list` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct ListType<'ctx> { ty: PointerType<'ctx>, + item: Option>, llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ListStructFields<'ctx> { + /// Array pointer to content. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub items: StructField<'ctx, PointerValue<'ctx>>, + + /// Number of items in the array. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ListStructFields<'ctx> { + #[must_use] + pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + ListStructFields { + items: StructField::create( + &mut counter, + "items", + item.ptr_type(AddressSpace::default()), + ), + len: StructField::create(&mut counter, "len", llvm_usize), + } + } +} + impl<'ctx> ListType<'ctx> { /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. pub fn is_representable( llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Result<(), String> { - let llvm_list_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); - }; - if llvm_list_ty.count_fields() != 2 { - return Err(format!( - "Expected 2 fields in `list`, got {}", - llvm_list_ty.count_fields() - )); - } + let ctx = llvm_ty.get_context(); - let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); - let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); + let llvm_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); }; - let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); - let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); - }; - if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width() - )); - } + let fields = ListStructFields::new(ctx, llvm_usize); - Ok(()) + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { + ListStructFields::new_typed(item, llvm_usize) + } + + /// See [`ListType::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self, _ctx: &impl AsContextRef<'ctx>) -> ListStructFields<'ctx> { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) } /// Creates an LLVM type corresponding to the expected structure of a `List`. #[must_use] fn llvm_type( ctx: &'ctx Context, - element_type: BasicTypeEnum<'ctx>, + element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - // struct List { data: T*, size: size_t } - let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()]; + let element_type = element_type.unwrap_or(llvm_usize.into()); + + let field_tys = + Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } @@ -75,9 +120,50 @@ impl<'ctx> ListType<'ctx> { element_type: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - ListType::from_type(llvm_list, llvm_usize) + Self { ty: llvm_list, item: Some(element_type), llvm_usize } + } + + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + + Self { ty: llvm_list, item: None, llvm_usize } + } + + /// Creates an [`ListType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `item_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { + None + } else { + Some(ctx.get_llvm_type(generator, elem_type)) + }; + + Self { + ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), + item: llvm_elem_type, + llvm_usize, + } } /// Creates an [`ListType`] from a [`PointerType`]. @@ -85,30 +171,39 @@ impl<'ctx> ListType<'ctx> { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - ListType { ty: ptr_ty, llvm_usize } + let ctx = ptr_ty.get_context(); + + // We are just searching for the index off a field - Slot an arbitrary element type in. + let item_field_idx = + Self::fields(ctx.i8_type().into(), llvm_usize).index_of_field(|f| f.items); + let item = unsafe { + ptr_ty + .get_element_type() + .into_struct_type() + .get_field_type_at_index_unchecked(item_field_idx) + .into_pointer_type() + .get_element_type() + }; + let item = BasicTypeEnum::try_from(item).unwrap_or_else(|()| { + panic!( + "Expected BasicTypeEnum for list element type, got {}", + ptr_ty.get_element_type().print_to_string() + ) + }); + + ListType { ty: ptr_ty, item: Some(item), llvm_usize } } /// Returns the type of the `size` field of this `list` type. #[must_use] pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(1) - .map(BasicTypeEnum::into_int_type) - .unwrap() + self.llvm_usize } /// Returns the element type of this `list` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> Option> { + self.item } /// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type. @@ -144,6 +239,73 @@ impl<'ctx> ListType<'ctx> { ) } + /// Allocates a [`ListValue`] on the stack using `item` of this [`ListType`] instance. + /// + /// The returned list will contain: + /// + /// - `data`: Allocated with `len` number of elements. + /// - `len`: Initialized to the value of `len` passed to this function. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + let len = ctx.builder.build_int_z_extend(len, self.llvm_usize, "").unwrap(); + + // Generate a runtime assertion if allocating a non-empty list with unknown element type + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None && self.item.is_none() { + let len_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, len, self.llvm_usize.const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + len_eqz, + "0:AssertionError", + "Cannot allocate a non-empty list with unknown element type", + [None, None, None], + ctx.current_loc, + ); + } + + let plist = self.alloca_var(generator, ctx, name); + plist.store_size(ctx, generator, len); + + let item = self.item.unwrap_or(self.llvm_usize.into()); + plist.create_data(ctx, item, None); + + plist + } + + /// Convenience function for creating a list with zero elements. + /// + /// This function is preferred over [`ListType::construct`] if the length is known to always be + /// 0, as this function avoids injecting an IR assertion for checking if a non-empty untyped + /// list is being allocated. + /// + /// The returned list will contain: + /// + /// - `data`: Initialized to `(T*) 0`. + /// - `len`: Initialized to `0`. + #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + let plist = self.alloca_var(generator, ctx, name); + + plist.store_size(ctx, generator, self.llvm_usize.const_zero()); + plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); + + plist + } + /// Converts an existing value into a [`ListValue`]. #[must_use] pub fn map_value( diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 4d6dcaf..87781d1 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -5,6 +5,7 @@ use inkwell::{ types::{BasicTypeEnum, IntType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, }; +use itertools::Itertools; use crate::codegen::CodeGenContext; @@ -55,6 +56,20 @@ pub trait StructFields<'ctx>: Eq + Copy { { self.into_vec().into_iter() } + + /// Returns the field index of a field in this structure. + fn index_of_field(&self, name: impl FnOnce(&Self) -> StructField<'ctx, V>) -> u32 + where + V: BasicValue<'ctx> + TryFrom, Error = ()>, + { + let field_name = name(self).name; + self.index_of_field_name(field_name).unwrap() + } + + /// Returns the field index of a field with the given name in this structure. + fn index_of_field_name(&self, field_name: &str) -> Option { + self.iter().find_position(|(name, _)| *name == field_name).map(|(idx, _)| idx as u32) + } } /// A single field of an LLVM structure. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 549bfe3..bd115a2 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -8,7 +8,7 @@ use super::{ ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::ListType, + types::{structure::StructField, ListType}, {CodeGenContext, CodeGenerator}, }; @@ -42,48 +42,26 @@ impl<'ctx> ListValue<'ctx> { ListValue { value: ptr, llvm_usize, name } } + fn items_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).items + } + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Returns the pointer to the field storing the size of this `list`. - fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + self.items_field(ctx).ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); + self.items_field(ctx).set(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element /// type `elem_ty` and `size`. /// - /// If `size` is [None], the size stored in the field of this instance is used instead. + /// If `size` is [None], the size stored in the field of this instance is used instead. If + /// `size` is resolved to `0` at runtime, `(T*) 0` will be assigned to `data`. pub fn create_data( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -114,6 +92,10 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } + fn len_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields(&ctx.ctx).len + } + /// Stores the `size` of this `list` into this instance. pub fn store_size( &self, @@ -123,22 +105,16 @@ impl<'ctx> ListValue<'ctx> { ) { debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); - let psize = self.ptr_to_size(ctx); - ctx.builder.build_store(psize, size).unwrap(); + self.len_field(ctx).set(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. - pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let psize = self.ptr_to_size(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.size"))) - .unwrap_or_default(); - - ctx.builder - .build_load(psize, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() + pub fn load_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> IntValue<'ctx> { + self.len_field(ctx).get(ctx, self.value, name) } }