diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 17952369..85b963bb 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -41,6 +41,7 @@ pub mod extern_fns; mod generator; pub mod irrt; pub mod llvm_intrinsics; +pub mod model; pub mod numpy; pub mod stmt; diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs new file mode 100644 index 00000000..bcc4a6b7 --- /dev/null +++ b/nac3core/src/codegen/model/core.rs @@ -0,0 +1,92 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicTypeEnum}, + values::{AnyValue, AnyValueEnum, BasicValueEnum}, +}; + +use crate::codegen::CodeGenContext; + +use super::{slice::ArraySlice, Int, Pointer}; + +/// A value that belongs to/produced by a [`Model<'ctx>`] +pub trait ModelValue<'ctx>: Clone + Copy { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx>; +} + +// Should have been within [`Model<'ctx>`], +// but rust object safety requirements made it necessary to +// split this interface out +pub trait CanCheckLLVMType<'ctx> { + /// Check if `scrutinee` matches the same LLVM type of this [`Model<'ctx>`]. + /// + /// If they don't not match, a human-readable error message is returned. + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String>; +} + +/// A [`Model`] is a type-safe concrete representation of a complex LLVM type. +pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType<'ctx> + Sized { + /// The values that inhabit this [`Model<'ctx>`]. + /// + /// ...that is the type of wrapper that wraps the LLVM values that inhabit [`Model<'ctx>::get_llvm_type()`]. + type Value: ModelValue<'ctx>; + + /// Get the [`BasicTypeEnum<'ctx>`] this [`Model<'ctx>`] is representing. + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; + + /// Cast an [`AnyValueEnum<'ctx>`] into [`Self::Value`]. + /// + /// Panics if `value` cannot pass [`CanCheckLLVMType::check_llvm_type()`]. + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value; + + /// Check if [`Self::Value`] has the correct type described by this [`Model<'ctx>`] + /// + /// For example: + /// ```ignore + /// let ctx: &CodeGenContext<'ctx, '_>; + /// let my_i32 = IntModel(ctx.ctx.i32_type()); + /// let my_i64 = IntModel(ctx.ctx.i64_type()); + /// let value1 = my_i32.constant(3); + /// let value2 = my_i64.constant(3); + /// // Both value1 and value2 have type `IntModel<'ctx>`! + /// // There is no type constraints to tell which value has what int type. + /// my_i32.check(value1); // ok + /// my_i64.check(value2); // ok + /// + /// my_i32.check(value2); // PANIC + /// my_i64.check(value1); // PANIC + /// ``` + fn check(&self, ctx: &'ctx Context, value: Self::Value) { + self.review(ctx, value.get_llvm_value().as_any_value_enum()); + } + + /// Build an instruction to allocate a value of [`Model::get_llvm_type`]. + fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> { + Pointer { + element: *self, + value: ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(), + } + } + + /// Build an instruction to allocate an array of [`Model::get_llvm_type`]. + fn array_alloca( + &self, + ctx: &CodeGenContext<'ctx, '_>, + count: Int<'ctx>, + name: &str, + ) -> ArraySlice<'ctx, Self> { + ArraySlice { + num_elements: count, + pointer: Pointer { + element: *self, + value: ctx + .builder + .build_array_alloca(self.get_llvm_type(ctx.ctx), count.0, name) + .unwrap(), + }, + } + } +} diff --git a/nac3core/src/codegen/model/fixed_int.rs b/nac3core/src/codegen/model/fixed_int.rs new file mode 100644 index 00000000..102e23cf --- /dev/null +++ b/nac3core/src/codegen/model/fixed_int.rs @@ -0,0 +1,159 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, + values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue}, +}; + +use crate::codegen::CodeGenContext; + +use super::{ + core::*, + int_util::{check_int_llvm_type, review_int_llvm_value}, + Int, IntModel, +}; + +/// A marker trait to mark singleton struct that describes a particular fixed integer type. +/// See [`Bool`], [`Byte`], [`Int32`], etc. +/// +/// The [`Default`] trait is to enable auto-derivations for utilities like +/// [`FieldBuilder::add_field_auto`] +pub trait IsFixedInt: Clone + Copy + Default { + fn get_int_type(ctx: &Context) -> IntType<'_>; + fn get_bit_width() -> u32; // This is required, instead of only relying on get_int_type + + fn constant<'ctx>(&self, ctx: &'ctx Context, value: u64) -> FixedInt<'ctx, Self> { + FixedInt { int: *self, value: Self::get_int_type(ctx).const_int(value, false) } + } +} + +// Some pre-defined fixed integers + +#[derive(Debug, Clone, Copy, Default)] +pub struct Bool; +pub type BoolModel = FixedIntModel; + +impl IsFixedInt for Bool { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.bool_type() + } + + fn get_bit_width() -> u32 { + 1 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Byte; +pub type ByteModel = FixedIntModel; + +impl IsFixedInt for Byte { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i8_type() + } + + fn get_bit_width() -> u32 { + 8 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Int32; +pub type Int32Model = FixedIntModel; + +impl IsFixedInt for Int32 { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i32_type() + } + + fn get_bit_width() -> u32 { + 32 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Int64; +pub type Int64Model = FixedIntModel; + +impl IsFixedInt for Int64 { + fn get_int_type(ctx: &Context) -> IntType<'_> { + ctx.i64_type() + } + + fn get_bit_width() -> u32 { + 64 + } +} + +/// A model representing a compile-time known [`IntType<'ctx>`]. +/// +/// Also see [`IntModel`], which is less constrained than [`FixedIntModel`], +/// but enables one to handle [`IntType<'ctx>`] that could be dynamic +#[derive(Debug, Clone, Copy, Default)] +pub struct FixedIntModel(pub T); + +// FixedIntModel's implementation + +impl<'ctx, T: IsFixedInt> CanCheckLLVMType<'ctx> for FixedIntModel { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + check_int_llvm_type(scrutinee, T::get_int_type(ctx)) + } +} + +impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel { + type Value = FixedInt<'ctx, T>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + T::get_int_type(ctx).as_basic_type_enum() + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + let value = review_int_llvm_value(value, T::get_int_type(ctx)).unwrap(); + FixedInt { int: self.0, value } + } +} + +impl FixedIntModel { + pub fn to_int_model(self, ctx: &Context) -> IntModel<'_> { + IntModel(T::get_int_type(ctx)) + } +} + +/// An inhabitant of [`FixedIntModel<'ctx>`] +#[derive(Debug, Clone, Copy)] +pub struct FixedInt<'ctx, T: IsFixedInt> { + pub int: T, + pub value: IntValue<'ctx>, +} + +// FixedInt's Implementation + +impl<'ctx, T: IsFixedInt> ModelValue<'ctx> for FixedInt<'ctx, T> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, T: IsFixedInt> FixedInt<'ctx, T> { + pub fn to_int(self) -> Int<'ctx> { + Int(self.value) + } + + pub fn signed_cast_to_fixed( + self, + ctx: &CodeGenContext<'ctx, '_>, + target_fixed_int: R, + name: &str, + ) -> FixedInt<'ctx, R> { + FixedInt { + int: target_fixed_int, + value: ctx + .builder + .build_int_s_extend_or_bit_cast(self.value, R::get_int_type(ctx.ctx), name) + .unwrap(), + } + } +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs new file mode 100644 index 00000000..0aacae56 --- /dev/null +++ b/nac3core/src/codegen/model/int.rs @@ -0,0 +1,97 @@ +use inkwell::{ + context::Context, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, + values::{AnyValueEnum, BasicValue, BasicValueEnum, IntValue}, +}; + +use crate::codegen::CodeGenContext; + +use super::{core::*, int_util::check_int_llvm_type, FixedInt, IsFixedInt}; + +/// A model representing an [`IntType<'ctx>`]. +/// +/// Also see [`FixedIntModel`], which is more constrained than [`IntModel`] +/// but provides more type-safe mechanisms and even auto-derivation of [`BasicTypeEnum<'ctx>`] +/// for creating LLVM structures. +#[derive(Debug, Clone, Copy)] +pub struct IntModel<'ctx>(pub IntType<'ctx>); + +impl<'ctx> CanCheckLLVMType<'ctx> for IntModel<'ctx> { + fn check_llvm_type( + &self, + _ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + check_int_llvm_type(scrutinee, self.0) + } +} + +impl<'ctx> Model<'ctx> for IntModel<'ctx> { + type Value = Int<'ctx>; + + fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + let int = value.into_int_value(); + self.check_llvm_type(ctx, int.get_type().as_any_type_enum()).unwrap(); + Int(int) + } +} + +impl<'ctx> IntModel<'ctx> { + /// Create a constant value that inhabits this [`IntModel<'ctx>`]. + #[must_use] + pub fn constant(&self, value: u64) -> Int<'ctx> { + Int(self.0.const_int(value, false)) + } + + /// Check if `other` is fully compatible with this [`IntModel<'ctx>`]. + /// + /// This simply checks if the underlying [`IntType<'ctx>`] has + /// the same number of bits. + #[must_use] + pub fn same_as(&self, other: IntModel<'ctx>) -> bool { + // TODO: or `self.0 == other.0` would also work? + self.0.get_bit_width() == other.0.get_bit_width() + } +} + +/// An inhabitant of an [`IntModel<'ctx>`] +#[derive(Debug, Clone, Copy)] +pub struct Int<'ctx>(pub IntValue<'ctx>); + +impl<'ctx> ModelValue<'ctx> for Int<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.0.as_basic_value_enum() + } +} + +impl<'ctx> Int<'ctx> { + #[must_use] + pub fn signed_cast_to_int( + self, + ctx: &CodeGenContext<'ctx, '_>, + target_int: IntModel<'ctx>, + name: &str, + ) -> Int<'ctx> { + Int(ctx.builder.build_int_s_extend_or_bit_cast(self.0, target_int.0, name).unwrap()) + } + + #[must_use] + pub fn signed_cast_to_fixed( + self, + ctx: &CodeGenContext<'ctx, '_>, + target_fixed: T, + name: &str, + ) -> FixedInt<'ctx, T> { + FixedInt { + int: target_fixed, + value: ctx + .builder + .build_int_s_extend_or_bit_cast(self.0, T::get_int_type(ctx.ctx), name) + .unwrap(), + } + } +} diff --git a/nac3core/src/codegen/model/int_util.rs b/nac3core/src/codegen/model/int_util.rs new file mode 100644 index 00000000..7f9ea7af --- /dev/null +++ b/nac3core/src/codegen/model/int_util.rs @@ -0,0 +1,39 @@ +use inkwell::{ + types::{AnyType, AnyTypeEnum, IntType}, + values::{AnyValueEnum, IntValue}, +}; + +/// Helper function to check if `scrutinee` is the same as `expected_int_type` +pub fn check_int_llvm_type<'ctx>( + scrutinee: AnyTypeEnum<'ctx>, + expected_int_type: IntType<'ctx>, +) -> Result<(), String> { + // Check if llvm_type is int type + let AnyTypeEnum::IntType(scrutinee) = scrutinee else { + return Err(format!("Expecting an int type but got {scrutinee:?}")); + }; + + // Check bit width + if scrutinee.get_bit_width() != expected_int_type.get_bit_width() { + return Err(format!( + "Expecting an int type of {}-bit(s) but got int type {}-bit(s)", + expected_int_type.get_bit_width(), + scrutinee.get_bit_width() + )); + } + + Ok(()) +} + +/// Helper function to cast `scrutinee` is into an [`IntValue<'ctx>`]. +/// The LLVM type of `scrutinee` will be checked with [`check_int_llvm_type`]. +pub fn review_int_llvm_value<'ctx>( + value: AnyValueEnum<'ctx>, + expected_int_type: IntType<'ctx>, +) -> Result, String> { + // Check if value is of int type, error if that is anything else + check_int_llvm_type(value.get_type().as_any_type_enum(), expected_int_type)?; + + // Ok, it is must be an int + Ok(value.into_int_value()) +} diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs new file mode 100644 index 00000000..2ebf02b8 --- /dev/null +++ b/nac3core/src/codegen/model/mod.rs @@ -0,0 +1,16 @@ +pub mod core; +pub mod fixed_int; +pub mod int; +mod int_util; +pub mod opaque; +pub mod pointer; +pub mod slice; +pub mod structure; + +pub use core::*; +pub use fixed_int::*; +pub use int::*; +pub use opaque::*; +pub use pointer::*; +pub use slice::*; +pub use structure::*; diff --git a/nac3core/src/codegen/model/opaque.rs b/nac3core/src/codegen/model/opaque.rs new file mode 100644 index 00000000..8b3ef426 --- /dev/null +++ b/nac3core/src/codegen/model/opaque.rs @@ -0,0 +1,46 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicTypeEnum}, + values::{AnyValueEnum, BasicValueEnum}, +}; + +use super::*; + +#[derive(Debug, Clone, Copy)] +pub struct OpaqueModel<'ctx>(pub BasicTypeEnum<'ctx>); + +impl<'ctx> CanCheckLLVMType<'ctx> for OpaqueModel<'ctx> { + fn check_llvm_type( + &self, + _ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + match BasicTypeEnum::try_from(scrutinee) { + Ok(_) => Ok(()), + Err(_err) => Err(format!("Expecting a BasicTypeEnum, but got {scrutinee:?}")), + } + } +} + +impl<'ctx> Model<'ctx> for OpaqueModel<'ctx> { + type Value = Opaque<'ctx>; + + fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0 + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + self.check_llvm_type(ctx, value.get_type()).unwrap(); + let value = BasicValueEnum::try_from(value).unwrap(); // Must work + Opaque(value) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Opaque<'ctx>(pub BasicValueEnum<'ctx>); + +impl<'ctx> ModelValue<'ctx> for Opaque<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.0 + } +} diff --git a/nac3core/src/codegen/model/pointer.rs b/nac3core/src/codegen/model/pointer.rs new file mode 100644 index 00000000..5dc1ea2a --- /dev/null +++ b/nac3core/src/codegen/model/pointer.rs @@ -0,0 +1,83 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum}, + values::{AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use crate::codegen::CodeGenContext; + +use super::{core::*, OpaqueModel}; + +/// A [`Model<'ctx>`] representing an LLVM [`PointerType<'ctx>`] +/// with *full* information on the element u +/// +/// [`self.0`] contains [`Model<'ctx>`] that represents the +/// LLVM type of element of the [`PointerType<'ctx>`] is pointing at +/// (like `PointerType<'ctx>::get_element_type()`, but abstracted as a [`Model<'ctx>`]). +#[derive(Debug, Clone, Copy, Default)] +pub struct PointerModel(pub E); + +impl<'ctx, E: Model<'ctx>> CanCheckLLVMType<'ctx> for PointerModel { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + // Check if scrutinee is even a PointerValue + let AnyTypeEnum::PointerType(scrutinee) = scrutinee else { + return Err(format!("Expecting a pointer value, but got {scrutinee:?}")); + }; + + // Check the type of what the pointer is pointing at + // TODO: This will be deprecated by inkwell > llvm14 because `get_element_type()` will be gone + self.0.check_llvm_type(ctx, scrutinee.get_element_type())?; // TODO: Include backtrace? + + Ok(()) + } +} + +impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel { + type Value = Pointer<'ctx, E>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + self.check_llvm_type(ctx, value.get_type()).unwrap(); + + // TODO: Check get_element_type(). For inkwell LLVM 14 at least... + Pointer { element: self.0, value: value.into_pointer_value() } + } +} + +/// An inhabitant of [`PointerModel`] +#[derive(Debug, Clone, Copy)] +pub struct Pointer<'ctx, E: Model<'ctx>> { + pub element: E, + pub value: PointerValue<'ctx>, +} + +impl<'ctx, E: Model<'ctx>> ModelValue<'ctx> for Pointer<'ctx, E> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { + /// Build an instruction to store a value into this pointer + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, val: E::Value) { + ctx.builder.build_store(self.value, val.get_llvm_value()).unwrap(); + } + + /// Build an instruction to load a value from this pointer + pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value { + let val = ctx.builder.build_load(self.value, name).unwrap(); + self.element.review(ctx.ctx, val.as_any_value_enum()) + } + + pub fn to_opaque(self, ctx: &'ctx Context) -> Pointer<'ctx, OpaqueModel<'ctx>> { + Pointer { element: OpaqueModel(self.element.get_llvm_type(ctx)), value: self.value } + } +} diff --git a/nac3core/src/codegen/model/slice.rs b/nac3core/src/codegen/model/slice.rs new file mode 100644 index 00000000..5441aba3 --- /dev/null +++ b/nac3core/src/codegen/model/slice.rs @@ -0,0 +1,73 @@ +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::{Int, Model, Pointer}; + +pub struct ArraySlice<'ctx, E: Model<'ctx>> { + pub num_elements: Int<'ctx>, + pub pointer: Pointer<'ctx, E>, +} + +impl<'ctx, E: Model<'ctx>> ArraySlice<'ctx, E> { + pub fn ix_unchecked( + &self, + ctx: &CodeGenContext<'ctx, '_>, + idx: Int<'ctx>, + name: &str, + ) -> Pointer<'ctx, E> { + let element_addr = + unsafe { ctx.builder.build_in_bounds_gep(self.pointer.value, &[idx.0], name).unwrap() }; + Pointer { value: element_addr, element: self.pointer.element } + } + + pub fn ix( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: Int<'ctx>, + name: &str, + ) -> Pointer<'ctx, E> { + let int_type = self.num_elements.0.get_type(); // NOTE: Weird get_type(), see comment under `trait Ixed` + + assert_eq!(int_type.get_bit_width(), idx.0.get_type().get_bit_width()); // Might as well check bit width to catch bugs + + // TODO: SGE or UGE? or make it defined by the implementee? + + // Check `0 <= index` + let lower_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLE, + int_type.const_zero(), + idx.0, + "lower_bounded", + ) + .unwrap(); + + // Check `index < num_elements` + let upper_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLT, + idx.0, + self.num_elements.0, + "upper_bounded", + ) + .unwrap(); + + // Compute `0 <= index && index < num_elements` + let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap(); + + // Assert `bounded` + ctx.make_assert( + generator, + bounded, + "0:IndexError", + "nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}", + [ Some(idx.0), Some(self.num_elements.0), None], + ctx.current_loc + ); + + // ...and finally do indexing + self.ix_unchecked(ctx, idx, name) + } +} diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs new file mode 100644 index 00000000..bb0404f4 --- /dev/null +++ b/nac3core/src/codegen/model/structure.rs @@ -0,0 +1,219 @@ +use inkwell::{ + context::Context, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, StructType}, + values::{AnyValueEnum, BasicValue, BasicValueEnum, StructValue}, +}; +use itertools::{izip, Itertools}; + +use crate::codegen::CodeGenContext; + +use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer}; + +#[derive(Debug, Clone, Copy)] +pub struct Field { + pub gep_index: u64, + pub name: &'static str, + pub element: E, +} + +struct FieldLLVM<'ctx> { + gep_index: u64, + name: &'ctx str, + llvm_type: BasicTypeEnum<'ctx>, + + // Only CanCheckLLVMType is needed, dont use `Model<'ctx>` + llvm_type_model: Box + 'ctx>, +} + +pub struct FieldBuilder<'ctx> { + pub ctx: &'ctx Context, + gep_index_counter: u64, + struct_name: &'ctx str, + fields: Vec>, +} + +impl<'ctx> FieldBuilder<'ctx> { + #[must_use] + pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self { + FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() } + } + + fn next_gep_index(&mut self) -> u64 { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + index + } + + pub fn add_field + 'ctx>(&mut self, name: &'static str, element: E) -> Field { + let gep_index = self.next_gep_index(); + + self.fields.push(FieldLLVM { + gep_index, + name, + llvm_type: element.get_llvm_type(self.ctx), + llvm_type_model: Box::new(element), + }); + + Field { gep_index, name, element } + } + + pub fn add_field_auto + Default + 'ctx>( + &mut self, + name: &'static str, + ) -> Field { + self.add_field(name, E::default()) + } +} + +/// A marker trait to mark singleton struct that describes a particular LLVM structure. +pub trait IsStruct<'ctx>: Clone + Copy { + /// The type of the Rust `struct` that holds all the fields of this LLVM struct. + type Fields; + + /// A cosmetic name for this struct. + /// TODO: Currently unused. To be used in error reporting. + fn struct_name(&self) -> &'static str; + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields; + + fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder) + } + + /// Get the LLVM struct type this [`IsStruct<'ctx>`] is representing. + fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder); // Self::Fields is discarded + + let field_types = builder.fields.iter().map(|f| f.llvm_type).collect_vec(); + ctx.struct_type(&field_types, false) + } + + /// Check if `scrutinee` matches the [`StructType<'ctx>`] this [`IsStruct<'ctx>`] is representing. + fn check_struct_type( + &self, + ctx: &'ctx Context, + scrutinee: StructType<'ctx>, + ) -> Result<(), String> { + // Details about scrutinee + let scrutinee_field_types = scrutinee.get_field_types(); + + // Details about the defined specifications of this struct + // We will access them through builder + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder); + + // Check # of fields + if builder.fields.len() != scrutinee_field_types.len() { + return Err(format!( + "Expecting struct to have {} field(s), but scrutinee has {} field(s)", + builder.fields.len(), + scrutinee_field_types.len() + )); + } + + // Check the types of each field + // TODO: Traceback? + for (f, scrutinee_field_type) in izip!(builder.fields, scrutinee_field_types) { + f.llvm_type_model.check_llvm_type(ctx, scrutinee_field_type.as_any_type_enum())?; + } + + Ok(()) + } +} + +/// A [`Model<'ctx>`] that represents an LLVM struct. +/// +/// `self.0` contains a [`IsStruct<'ctx>`] that gives the details of the LLVM struct. +#[derive(Debug, Clone, Copy, Default)] +pub struct StructModel(pub S); + +impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel { + fn check_llvm_type( + &self, + ctx: &'ctx Context, + scrutinee: AnyTypeEnum<'ctx>, + ) -> Result<(), String> { + // Check if scrutinee is even a struct type + let AnyTypeEnum::StructType(scrutinee) = scrutinee else { + return Err(format!("Expecting a struct type, but got {scrutinee:?}")); + }; + + // Ok. now check the struct type *thoroughly* + self.0.check_struct_type(ctx, scrutinee) + } +} + +impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel { + type Value = Struct<'ctx, S>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_struct_type(ctx).as_basic_type_enum() + } + + fn review(&self, ctx: &'ctx Context, value: AnyValueEnum<'ctx>) -> Self::Value { + // Check that `value` is not some bogus values or an incorrect StructValue + self.check_llvm_type(ctx, value.get_type()).unwrap(); + + Struct { structure: self.0, value: value.into_struct_value() } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Struct<'ctx, S> { + pub structure: S, + pub value: StructValue<'ctx>, +} + +impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.value.as_basic_value_enum() + } +} + +impl<'ctx, S: IsStruct<'ctx>> Pointer<'ctx, StructModel> { + /// Build an instruction that does `getelementptr` on an LLVM structure referenced by this pointer. + /// + /// This provides a nice syntax to chain up `getelementptr` in an intuitive and type-safe way: + /// + /// ```ignore + /// let ctx: &CodeGenContext<'ctx, '_>; + /// let ndarray: Pointer<'ctx, StructModel>>; + /// ndarray.gep(ctx, |f| f.ndims).store(); + /// ``` + /// + /// You might even write chains `gep`, i.e., + /// ```ignore + /// my_struct + /// .gep(ctx, |f| f.thing1) + /// .gep(ctx, |f| f.value) + /// .store(ctx, my_value) // Equivalent to `my_struct.thing1.value = my_value` + /// ``` + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetFieldFn, + ) -> Pointer<'ctx, E> + where + E: Model<'ctx>, + GetFieldFn: FnOnce(S::Fields) -> Field, + { + let fields = self.element.0.get_fields(ctx.ctx); + let field = get_field(fields); + + let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that + + let ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + self.value, + &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + field.name, + ) + .unwrap() + }; + + Pointer { element: field.element, value: ptr } + } +}