From c772fdb83a6c8db5e7931852f5ea0d67277ca23a Mon Sep 17 00:00:00 2001 From: lyken Date: Sat, 27 Jul 2024 18:21:01 +0800 Subject: [PATCH] core/model: introduce codegen/model --- nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/model/core.rs | 161 +++++++++++++++++ nac3core/src/codegen/model/int.rs | 228 ++++++++++++++++++++++++ nac3core/src/codegen/model/mod.rs | 11 ++ nac3core/src/codegen/model/ptr.rs | 141 +++++++++++++++ nac3core/src/codegen/model/slice.rs | 72 ++++++++ nac3core/src/codegen/model/structure.rs | 174 ++++++++++++++++++ 7 files changed, 788 insertions(+) create mode 100644 nac3core/src/codegen/model/core.rs create mode 100644 nac3core/src/codegen/model/int.rs create mode 100644 nac3core/src/codegen/model/mod.rs create mode 100644 nac3core/src/codegen/model/ptr.rs create mode 100644 nac3core/src/codegen/model/slice.rs create mode 100644 nac3core/src/codegen/model/structure.rs diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2f4a9ec5..dc53953c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -41,6 +41,7 @@ pub mod extern_fns; mod generator; pub mod irrt; pub mod llvm_intrinsics; +pub mod model; pub mod numpy; pub mod stmt; diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs new file mode 100644 index 00000000..d2dcd8f0 --- /dev/null +++ b/nac3core/src/codegen/model/core.rs @@ -0,0 +1,161 @@ +use std::fmt; + +use inkwell::{context::Context, types::*, values::*}; + +use super::*; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +#[derive(Clone, Copy)] +pub struct TypeContext<'ctx> { + pub size_type: IntType<'ctx>, +} + +pub trait HasTypeContext { + fn type_context<'ctx>(&self, ctx: &'ctx Context) -> TypeContext<'ctx>; +} + +impl HasTypeContext for T { + fn type_context<'ctx>(&self, ctx: &'ctx Context) -> TypeContext<'ctx> { + TypeContext { size_type: self.get_size_type(ctx) } + } +} + +#[derive(Debug, Clone)] +pub struct ModelError(pub String); + +impl ModelError { + pub(super) fn under_context(mut self, context: &str) -> Self { + self.0.push_str(" ... in "); + self.0.push_str(context); + self + } +} + +/// A [`Model`] is a singleton object that uniquely identifies a [`BasicType`] +/// solely from a [`CodeGenerator`] and a [`Context`]. +pub trait Model: CheckType + fmt::Debug + Clone + Copy + Default { + type Value<'ctx>: BasicValue<'ctx> + TryFrom>; + type Type<'ctx>: BasicType<'ctx>; + + /// Return the [`BasicType`] of this model. + fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx>; + + /// Check if a [`BasicType`] is the same type of this model. + fn check_type<'ctx, T: BasicType<'ctx>>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + self.check_type_impl(tyctx, ctx, ty.as_basic_type_enum()) + } + + /// Create an instance from a value with [`Instance::model`] being this model. + /// + /// Caller must make sure the type of `value` and the type of this `model` are equivalent. + fn believe_value<'ctx>(&self, value: Self::Value<'ctx>) -> Instance<'ctx, Self> { + Instance { model: *self, value } + } + + /// Check if a [`BasicValue`]'s type is equivalent to the type of this model. + /// Wrap it into an [`Instance`] if it is. + fn check_value<'ctx, V: BasicValue<'ctx>>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + value: V, + ) -> Result, ModelError> { + let value = value.as_basic_value_enum(); + self.check_type(tyctx, ctx, value.get_type()) + .map_err(|err| err.under_context("the value {value:?}"))?; + + let Ok(value) = Self::Value::try_from(value) else { + unreachable!("check_type() has bad implementation") + }; + Ok(self.believe_value(value)) + } + + // Allocate a value on the stack and return its pointer. + fn alloca<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + name: &str, + ) -> Ptr<'ctx, Self> { + let pmodel = PtrModel(*self); + let p = ctx.builder.build_alloca(self.get_type(tyctx, ctx.ctx), name).unwrap(); + pmodel.believe_value(p) + } + + // Allocate an array on the stack and return its pointer. + fn array_alloca<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: &str, + ) -> Ptr<'ctx, Self> { + let pmodel = PtrModel(*self); + let p = ctx.builder.build_array_alloca(self.get_type(tyctx, ctx.ctx), len, name).unwrap(); + pmodel.believe_value(p) + } + + fn var_alloca<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&str>, + ) -> Result, String> { + let tyctx = generator.type_context(ctx.ctx); + + let pmodel = PtrModel(*self); + let p = generator.gen_var_alloc( + ctx, + self.get_type(tyctx, ctx.ctx).as_basic_type_enum(), + name, + )?; + Ok(pmodel.believe_value(p)) + } + + fn array_var_alloca<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> Result, String> { + let tyctx = generator.type_context(ctx.ctx); + + // TODO: Remove ArraySliceValue + let pmodel = PtrModel(*self); + let p = generator.gen_array_var_alloc( + ctx, + self.get_type(tyctx, ctx.ctx).as_basic_type_enum(), + len, + name, + )?; + Ok(pmodel.believe_value(PointerValue::from(p))) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Instance<'ctx, M: Model> { + /// The model of this instance. + pub model: M, + /// The value of this instance. + /// + /// Caller must make sure the type of `value` and the type of this `model` are equivalent, + /// down to having the same [`IntType::get_bit_width`] in case of [`IntType`] for example. + pub value: M::Value<'ctx>, +} + +// NOTE: Must be Rust object-safe - This must be typeable for a Rust trait object. +pub trait CheckType { + fn check_type_impl<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ty: BasicTypeEnum<'ctx>, + ) -> Result<(), ModelError>; +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs new file mode 100644 index 00000000..5953c22c --- /dev/null +++ b/nac3core/src/codegen/model/int.rs @@ -0,0 +1,228 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicTypeEnum, IntType}, + values::IntValue, + IntPredicate, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +pub trait IntKind: fmt::Debug + Clone + Copy + Default { + fn get_int_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Bool; +#[derive(Debug, Clone, Copy, Default)] +pub struct Byte; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int64; +#[derive(Debug, Clone, Copy, Default)] +pub struct SizeT; + +impl IntKind for Bool { + fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { + ctx.bool_type() + } +} + +impl IntKind for Byte { + fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { + ctx.i8_type() + } +} + +impl IntKind for Int32 { + fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { + ctx.i32_type() + } +} + +impl IntKind for Int64 { + fn get_int_type<'ctx>(&self, _tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { + ctx.i64_type() + } +} + +impl IntKind for SizeT { + fn get_int_type<'ctx>(&self, tyctx: TypeContext<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> { + tyctx.size_type + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct IntModel(pub N); +pub type Int<'ctx, N> = Instance<'ctx, IntModel>; + +impl CheckType for IntModel { + fn check_type_impl<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ty: BasicTypeEnum<'ctx>, + ) -> Result<(), ModelError> { + let Ok(ty) = IntType::try_from(ty) else { + return Err(ModelError(format!("Expecting IntType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_int_type(tyctx, ctx); + if ty.get_bit_width() != exp_ty.get_bit_width() { + return Err(ModelError(format!( + "Expecting IntType to have {} bit(s), but got {} bit(s)", + exp_ty.get_bit_width(), + ty.get_bit_width() + ))); + } + + Ok(()) + } +} + +impl Model for IntModel { + type Value<'ctx> = IntValue<'ctx>; + type Type<'ctx> = IntType<'ctx>; + + #[must_use] + fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> { + self.0.get_int_type(tyctx, ctx) + } +} + +impl IntModel { + pub fn constant<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + value: u64, + ) -> Int<'ctx, N> { + let value = self.get_type(tyctx, ctx).const_int(value, false); + self.believe_value(value) + } + + pub fn const_0<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Int<'ctx, N> { + self.constant(tyctx, ctx, 0) + } + + pub fn const_1<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Int<'ctx, N> { + self.constant(tyctx, ctx, 1) + } + + pub fn s_extend_or_bit_cast<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + name: &str, + ) -> Int<'ctx, N> { + let value = ctx + .builder + .build_int_s_extend_or_bit_cast(value, self.get_type(tyctx, ctx.ctx), name) + .unwrap(); + self.believe_value(value) + } + + pub fn truncate<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + name: &str, + ) -> Int<'ctx, N> { + let value = + ctx.builder.build_int_truncate(value, self.get_type(tyctx, ctx.ctx), name).unwrap(); + self.believe_value(value) + } +} + +impl IntModel { + #[must_use] + pub fn const_false<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ) -> Int<'ctx, Bool> { + self.constant(tyctx, ctx, 0) + } + + #[must_use] + pub fn const_true<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ) -> Int<'ctx, Bool> { + self.constant(tyctx, ctx, 1) + } +} + +impl<'ctx, N: IntKind> Int<'ctx, N> { + pub fn s_extend_or_bit_cast( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + name: &str, + ) -> Int<'ctx, NewN> { + IntModel(to_int_kind).s_extend_or_bit_cast(tyctx, ctx, self.value, name) + } + + pub fn truncate( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + name: &str, + ) -> Int<'ctx, NewN> { + IntModel(to_int_kind).truncate(tyctx, ctx, self.value, name) + } + + #[must_use] + pub fn add( + &self, + ctx: &CodeGenContext<'ctx, '_>, + other: Int<'ctx, N>, + name: &str, + ) -> Int<'ctx, N> { + let value = ctx.builder.build_int_add(self.value, other.value, name).unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn sub( + &self, + ctx: &CodeGenContext<'ctx, '_>, + other: Int<'ctx, N>, + name: &str, + ) -> Int<'ctx, N> { + let value = ctx.builder.build_int_sub(self.value, other.value, name).unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn mul( + &self, + ctx: &CodeGenContext<'ctx, '_>, + other: Int<'ctx, N>, + name: &str, + ) -> Int<'ctx, N> { + let value = ctx.builder.build_int_mul(self.value, other.value, name).unwrap(); + self.model.believe_value(value) + } + + pub fn compare( + &self, + ctx: &CodeGenContext<'ctx, '_>, + op: IntPredicate, + other: Int<'ctx, N>, + name: &str, + ) -> Int<'ctx, Bool> { + let bool_model = IntModel(Bool); + let value = ctx.builder.build_int_compare(op, self.value, other.value, name).unwrap(); + bool_model.believe_value(value) + } +} diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs new file mode 100644 index 00000000..78da8048 --- /dev/null +++ b/nac3core/src/codegen/model/mod.rs @@ -0,0 +1,11 @@ +mod core; +mod int; +mod ptr; +mod slice; +mod structure; + +pub use core::*; +pub use int::*; +pub use ptr::*; +pub use slice::*; +pub use structure::*; diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs new file mode 100644 index 00000000..6fe5deef --- /dev/null +++ b/nac3core/src/codegen/model/ptr.rs @@ -0,0 +1,141 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; + +use crate::codegen::CodeGenContext; + +use super::*; + +#[derive(Debug, Clone, Copy, Default)] +pub struct PtrModel(pub Element); +pub type Ptr<'ctx, Element> = Instance<'ctx, PtrModel>; + +impl CheckType for PtrModel { + fn check_type_impl<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ty: BasicTypeEnum<'ctx>, + ) -> Result<(), super::ModelError> { + let Ok(ty) = PointerType::try_from(ty) else { + return Err(ModelError(format!("Expecting PointerType, but got {ty:?}"))); + }; + + let elem_ty = ty.get_element_type(); + let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else { + return Err(ModelError(format!( + "Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}" + ))); + }; + + // TODO: inkwell `get_element_type()` will be deprecated. + // Remove the check for `get_element_type()` when the time comes. + self.0 + .check_type_impl(tyctx, ctx, elem_ty) + .map_err(|err| err.under_context("a PointerType"))?; + + Ok(()) + } +} + +impl Model for PtrModel { + type Value<'ctx> = PointerValue<'ctx>; + type Type<'ctx> = PointerType<'ctx>; + + fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> { + self.0.get_type(tyctx, ctx).ptr_type(AddressSpace::default()) + } +} + +impl PtrModel { + /// Return a ***constant*** nullptr. + pub fn nullptr<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ) -> Ptr<'ctx, Element> { + let ptr = self.get_type(tyctx, ctx).const_null(); + self.believe_value(ptr) + } + + pub fn transmute<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + name: &str, + ) -> Ptr<'ctx, Element> { + let ptr = ctx.builder.build_pointer_cast(ptr, self.get_type(tyctx, ctx.ctx), name).unwrap(); + self.believe_value(ptr) + } +} + +impl<'ctx, Element: Model> Ptr<'ctx, Element> { + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`]. + pub fn offset( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + offset: IntValue<'ctx>, + name: &str, + ) -> Ptr<'ctx, Element> { + let new_ptr = + unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], name).unwrap() }; + self.model.check_value(tyctx, ctx.ctx, new_ptr).unwrap() + } + + // Load the `i`-th element (0-based) on the array with [`inkwell::builder::Builder::build_in_bounds_gep`]. + pub fn ix( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + name: &str, + ) -> Instance<'ctx, Element> { + self.offset(tyctx, ctx, i, name).load(tyctx, ctx, name) + } + + /// Load the value with [`inkwell::builder::Builder::build_load`]. + pub fn load( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + name: &str, + ) -> Instance<'ctx, Element> { + let value = ctx.builder.build_load(self.value, name).unwrap(); + self.model.0.check_value(tyctx, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error. + } + + /// Store a value with [`inkwell::builder::Builder::build_store`]. + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Element>) { + ctx.builder.build_store(self.value, value.value).unwrap(); + } + + /// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`]. + pub fn transmute( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + new_model: NewElement, + name: &str, + ) -> Ptr<'ctx, NewElement> { + PtrModel(new_model).transmute(tyctx, ctx, self.value, name) + } + + /// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`]. + pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> { + let bool_model = IntModel(Bool); + let value = ctx.builder.build_is_null(self.value, name).unwrap(); + bool_model.believe_value(value) + } + + /// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`]. + pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> { + let bool_model = IntModel(Bool); + let value = ctx.builder.build_is_not_null(self.value, name).unwrap(); + bool_model.believe_value(value) + } +} diff --git a/nac3core/src/codegen/model/slice.rs b/nac3core/src/codegen/model/slice.rs new file mode 100644 index 00000000..0aaa8447 --- /dev/null +++ b/nac3core/src/codegen/model/slice.rs @@ -0,0 +1,72 @@ +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +/// A slice - literally just a pointer and a length value. +/// +/// NOTE: This is NOT a [`Model`]. +pub struct ArraySlice<'ctx, Len: IntKind, Item: Model> { + pub base: Ptr<'ctx, Item>, + pub len: Int<'ctx, Len>, +} + +impl<'ctx, Len: IntKind, Item: Model> ArraySlice<'ctx, Len, Item> { + /// Get the `idx`-nth element of this [`ArraySlice`], but doesn't do an assertion to see if `idx` is out of bounds or not. + /// + /// Also see [`ArraySlice::ix`]. + pub fn ix_unchecked( + &self, + tyctx: TypeContext<'ctx>, + ctx: &CodeGenContext<'ctx, '_>, + idx: Int<'ctx, Len>, + name: &str, + ) -> Ptr<'ctx, Item> { + let element_ptr = unsafe { + ctx.builder.build_in_bounds_gep(self.base.value, &[idx.value], name).unwrap() + }; + self.base.model.check_value(tyctx, ctx.ctx, element_ptr).unwrap() + } + + /// Call [`ArraySlice::ix_unchecked`], but checks if `idx` is in bounds, otherwise a runtime `IndexError` will be thrown. + pub fn ix( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + idx: Int<'ctx, Len>, + name: &str, + ) -> Ptr<'ctx, Item> { + let tyctx = generator.type_context(ctx.ctx); + let len_model = IntModel(Len::default()); + + // Assert `0 <= idx < length` and throw an Exception if `idx` is out of bounds + let lower_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLE, + len_model.constant(tyctx, ctx.ctx, 0).value, + idx.value, + "lower_bounded", + ) + .unwrap(); + let upper_bounded = ctx + .builder + .build_int_compare( + inkwell::IntPredicate::SLT, + idx.value, + self.len.value, + "upper_bounded", + ) + .unwrap(); + let bounded = ctx.builder.build_and(lower_bounded, upper_bounded, "bounded").unwrap(); + ctx.make_assert( + generator, + bounded, + "0:IndexError", + "nac3core LLVM codegen attempting to access out of bounds array index {0}. Must satisfy 0 <= index < {2}", + [ Some(idx.value), Some(self.len.value), None], + ctx.current_loc + ); + + self.ix_unchecked(tyctx, ctx, idx, name) + } +} diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs new file mode 100644 index 00000000..beeae44e --- /dev/null +++ b/nac3core/src/codegen/model/structure.rs @@ -0,0 +1,174 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, StructType}, + values::StructValue, +}; +use itertools::izip; + +use crate::codegen::CodeGenContext; + +use super::*; + +#[derive(Debug, Clone, Copy)] +pub struct GepField { + pub gep_index: u64, + pub name: &'static str, + pub model: M, +} + +pub trait FieldVisitor { + type Field; + + fn add(&mut self, name: &'static str) -> Self::Field; +} + +pub struct GepFieldVisitor { + gep_index_counter: u64, +} + +impl FieldVisitor for GepFieldVisitor { + type Field = GepField; + + fn add(&mut self, name: &'static str) -> Self::Field { + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + Self::Field { gep_index, name, model: M::default() } + } +} + +struct TypeFieldVisitor<'ctx> { + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + field_types: Vec>, +} + +impl<'ctx> FieldVisitor for TypeFieldVisitor<'ctx> { + type Field = (); + + fn add(&mut self, _name: &'static str) -> Self::Field { + self.field_types.push(M::default().get_type(self.tyctx, self.ctx).as_basic_type_enum()); + } +} + +struct CheckTypeEntry { + check_type: Box, + name: &'static str, +} + +struct CheckTypeFieldVisitor<'ctx> { + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + check_types: Vec, +} + +impl<'ctx> FieldVisitor for CheckTypeFieldVisitor<'ctx> { + type Field = (); + + fn add(&mut self, name: &'static str) -> Self::Field { + self.check_types.push(CheckTypeEntry { check_type: Box::::default(), name }); + } +} + +pub trait StructKind: fmt::Debug + Clone + Copy + Default { + type Fields; + + fn visit_fields(&self, visitor: &mut F) -> Self::Fields; + + fn fields(&self) -> Self::Fields { + self.visit_fields(&mut GepFieldVisitor { gep_index_counter: 0 }) + } + + fn get_struct_type<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ) -> StructType<'ctx> { + let mut visitor = TypeFieldVisitor { tyctx, ctx, field_types: Vec::new() }; + self.visit_fields(&mut visitor); + + ctx.struct_type(&visitor.field_types, false) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct StructModel(pub S); +pub type Struct<'ctx, S> = Instance<'ctx, StructModel>; + +impl CheckType for StructModel { + fn check_type_impl<'ctx>( + &self, + tyctx: TypeContext<'ctx>, + ctx: &'ctx Context, + ty: BasicTypeEnum<'ctx>, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = StructType::try_from(ty) else { + return Err(ModelError(format!("Expecting StructType, but got {ty:?}"))); + }; + + let field_types = ty.get_field_types(); + + let check_types = { + let mut builder = CheckTypeFieldVisitor { tyctx, ctx, check_types: Vec::new() }; + self.0.visit_fields(&mut builder); + builder.check_types + }; + + if check_types.len() != field_types.len() { + return Err(ModelError(format!( + "Expecting StructType to have {} field(s), but got {} field(s)", + check_types.len(), + field_types.len() + ))); + } + + for (field_i, (entry, field_type)) in izip!(check_types, field_types).enumerate() { + let field_at = field_i + 1; + + entry.check_type.check_type_impl(tyctx, ctx, field_type).map_err(|err| { + err.under_context(format!("struct field #{field_at} '{}'", entry.name).as_str()) + })?; + } + + Ok(()) + } +} + +impl Model for StructModel { + type Value<'ctx> = StructValue<'ctx>; + type Type<'ctx> = StructType<'ctx>; + + fn get_type<'ctx>(&self, tyctx: TypeContext<'ctx>, ctx: &'ctx Context) -> Self::Type<'ctx> { + self.0.get_struct_type(tyctx, ctx) + } +} + +impl<'ctx, S: StructKind> Ptr<'ctx, StructModel> { + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + ) -> Ptr<'ctx, M> + where + M: Model, + GetField: FnOnce(S::Fields) -> GepField, + { + let field = get_field(self.model.0 .0.fields()); + let llvm_i32 = ctx.ctx.i32_type(); // must be i32, if its i64 then rust segfaults + + let ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + self.value, + &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + field.name, + ) + .unwrap() + }; + + let ptr_model = PtrModel(field.model); + ptr_model.believe_value(ptr) + } +}