diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 71a2d52..ca5aa23 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -41,6 +41,7 @@ pub mod extern_fns; mod generator; pub mod irrt; pub mod llvm_intrinsics; +pub mod model; pub mod numpy; pub mod stmt; diff --git a/nac3core/src/codegen/model/any.rs b/nac3core/src/codegen/model/any.rs new file mode 100644 index 0000000..9df863e --- /dev/null +++ b/nac3core/src/codegen/model/any.rs @@ -0,0 +1,42 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum}, + values::BasicValueEnum, +}; + +use crate::codegen::CodeGenerator; + +use super::*; + +/// A [`Model`] of any [`BasicTypeEnum`]. +/// +/// Use this when it is infeasible to use model abstractions. +#[derive(Debug, Clone, Copy)] +pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>); + +impl<'ctx> Model<'ctx> for Any<'ctx> { + type Value = BasicValueEnum<'ctx>; + type Type = BasicTypeEnum<'ctx>; + + fn get_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> Self::Type { + self.0 + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + _generator: &mut G, + _ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + if ty == self.0 { + Ok(()) + } else { + Err(ModelError(format!("Expecting {}, but got {}", self.0, ty))) + } + } +} diff --git a/nac3core/src/codegen/model/array.rs b/nac3core/src/codegen/model/array.rs new file mode 100644 index 0000000..be8dc0b --- /dev/null +++ b/nac3core/src/codegen/model/array.rs @@ -0,0 +1,143 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{ArrayType, BasicType, BasicTypeEnum}, + values::{ArrayValue, IntValue}, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +/// Trait for Rust structs identifying length values for [`Array`]. +pub trait LenKind: fmt::Debug + Clone + Copy { + fn get_length(&self) -> u32; +} + +/// A statically known length. +#[derive(Debug, Clone, Copy, Default)] +pub struct Len; + +/// A dynamically known length. +#[derive(Debug, Clone, Copy)] +pub struct AnyLen(pub u32); + +impl LenKind for Len { + fn get_length(&self) -> u32 { + N + } +} + +impl LenKind for AnyLen { + fn get_length(&self) -> u32 { + self.0 + } +} + +/// A Model for an [`ArrayType`]. +/// +/// `Len` should be of a [`LenKind`] and `Item` should be a of [`Model`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct Array { + /// Length of this array. + pub len: Len, + /// [`Model`] of the array items. + pub item: Item, +} + +impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array { + type Value = ArrayValue<'ctx>; + type Type = ArrayType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.item.get_type(generator, ctx).array_type(self.len.get_length()) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let BasicTypeEnum::ArrayType(ty) = ty else { + return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}"))); + }; + + if ty.len() != self.len.get_length() { + return Err(ModelError(format!( + "Expecting ArrayType with size {}, but got an ArrayType with size {}", + ty.len(), + self.len.get_length() + ))); + } + + self.item + .check_type(generator, ctx, ty.get_element_type()) + .map_err(|err| err.under_context("an ArrayType"))?; + + Ok(()) + } +} + +impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr>> { + /// Get the pointer to the `i`-th (0-based) array element. + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let zero = ctx.ctx.i32_type().const_zero(); + let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() }; + + Ptr(self.model.0.item).believe_value(ptr) + } + + /// Like `gep` but `i` is a constant. + pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr> { + assert!( + i < u64::from(self.model.0.len.get_length()), + "Index {i} is out of bounds. Array length = {}", + self.model.0.len.get_length() + ); + + let i = ctx.ctx.i32_type().const_int(i, false); + self.gep(ctx, i) + } + + /// Convenience function equivalent to `.gep(...).load(...)`. + pub fn get( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + ) -> Instance<'ctx, Item> { + self.gep(ctx, i).load(generator, ctx) + } + + /// Like `get` but `i` is a constant. + pub fn get_const( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + i: u64, + ) -> Instance<'ctx, Item> { + self.gep_const(ctx, i).load(generator, ctx) + } + + /// Convenience function equivalent to `.gep(...).store(...)`. + pub fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + value: Instance<'ctx, Item>, + ) { + self.gep(ctx, i).store(ctx, value); + } + + /// Like `set` but `i` is a constant. + pub fn set_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64, value: Instance<'ctx, Item>) { + self.gep_const(ctx, i).store(ctx, value); + } +} diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs new file mode 100644 index 0000000..25faeea --- /dev/null +++ b/nac3core/src/codegen/model/core.rs @@ -0,0 +1,202 @@ +use std::fmt; + +use inkwell::{context::Context, types::*, values::*}; +use itertools::Itertools; + +use super::*; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// A error type for reporting any [`Model`]-related error (e.g., a [`BasicType`] mismatch). +#[derive(Debug, Clone)] +pub struct ModelError(pub String); + +impl ModelError { + // Append a context message to the error. + pub(super) fn under_context(mut self, context: &str) -> Self { + self.0.push_str(" ... in "); + self.0.push_str(context); + self + } +} + +/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`]. +/// +/// For instance, +/// - [`Int`] identifies an [`IntType`] with 32-bits. +/// - [`Int`] identifies an [`IntType`] with bit-width [`CodeGenerator::get_size_type`]. +/// - [`Ptr>`] identifies a [`PointerType`] that points to an [`IntType`] with bit-width [`CodeGenerator::get_size_type`]. +/// - [`Int`] identifies an [`IntType`] with bit-width of whatever is set in the [`AnyInt`] object. +/// - [`Any`] identifies a [`BasicType`] set in the [`Any`] object itself. +/// +/// You can get the [`BasicType`] out of a model with [`Model::get_type`]. +/// +/// Furthermore, [`Instance<'ctx, M>`] is a simple structure that carries a [`BasicValue`] with [`BasicType`] identified by model `M`. +/// +/// The main purpose of this abstraction is to have a more Rust type-safe way to use Inkwell and give type-hints for programmers. +/// +/// ### Notes on `Default` trait +/// +/// For some models like [`Int`] or [`Int`], they have a [`Default`] trait since just by looking at their types, it is possible +/// to tell the [`BasicType`]s they are identifying. +/// +/// This can be used to create strongly-typed interfaces accepting only values of a specific [`BasicType`] without having to worry about +/// writing debug assertions to check, for example, if the programmer has passed in an [`IntValue`] with the wrong bit-width. +/// ```ignore +/// fn give_me_i32_and_get_a_size_t_back<'ctx>(i32: Instance<'ctx, Int>) -> Instance<'ctx, Int> { +/// // code... +/// } +/// ``` +/// +/// ### Notes on converting between Inkwell and model. +/// +/// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int>`]. You can do use +/// [`Model::check_value`] or [`Model::believe_value`]. +/// ```ignore +/// let my_value: IntValue<'ctx>; +/// +/// let my_value = Int(Int32).check_value(my_value).unwrap(); // Panics if `my_value` is not 32-bit with a descriptive error message. +/// +/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time: +/// let my_value = Int(Int32).believe_value(my_value); +/// ``` +pub trait Model<'ctx>: fmt::Debug + Clone + Copy { + /// The [`BasicType`] *variant* this model is identifying. + type Type: BasicType<'ctx>; + + /// The [`BasicValue`] type of the [`BasicType`] of this model. + type Value: BasicValue<'ctx> + TryFrom>; + + /// Return the [`BasicType`] of this model. + #[must_use] + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; + + /// Get the number of bytes of the [`BasicType`] of this model. + fn sizeof( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> IntValue<'ctx> { + self.get_type(generator, ctx).size_of().unwrap() + } + + /// Check if a [`BasicType`] matches the [`BasicType`] of this model. + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError>; + + /// Create an instance from a value. + /// + /// Caller must make sure the type of `value` and the type of this `model` are equivalent. + #[must_use] + fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> { + Instance { model: *self, value } + } + + /// Check if a [`BasicValue`]'s type is equivalent to the type of this model. + /// Wrap the [`BasicValue`] into an [`Instance`] if it is. + fn check_value, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + value: V, + ) -> Result, ModelError> { + let value = value.as_basic_value_enum(); + self.check_type(generator, ctx, value.get_type()) + .map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?; + + let Ok(value) = Self::Value::try_from(value) else { + unreachable!("check_type() has bad implementation") + }; + Ok(self.believe_value(value)) + } + + // Allocate a value on the stack and return its pointer. + fn alloca( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr> { + let p = ctx.builder.build_alloca(self.get_type(generator, ctx.ctx), "").unwrap(); + Ptr(*self).believe_value(p) + } + + // Allocate an array on the stack and return its pointer. + fn array_alloca( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let p = ctx.builder.build_array_alloca(self.get_type(generator, ctx.ctx), len, "").unwrap(); + Ptr(*self).believe_value(p) + } + + fn var_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&str>, + ) -> Result>, String> { + let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); + let p = generator.gen_var_alloc(ctx, ty, name)?; + Ok(Ptr(*self).believe_value(p)) + } + + fn array_var_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> Result>, String> { + // TODO: Remove ArraySliceValue + let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); + let p = generator.gen_array_var_alloc(ctx, ty, len, name)?; + Ok(Ptr(*self).believe_value(PointerValue::from(p))) + } + + /// Allocate a constant array. + fn const_array( + &self, + generator: &mut G, + ctx: &'ctx Context, + values: &[Instance<'ctx, Self>], + ) -> Instance<'ctx, Array> { + macro_rules! make { + ($t:expr, $into_value:expr) => { + $t.const_array( + &values + .iter() + .map(|x| $into_value(x.value.as_basic_value_enum())) + .collect_vec(), + ) + }; + } + + let value = match self.get_type(generator, ctx).as_basic_type_enum() { + BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value), + BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value), + BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value), + BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value), + BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value), + BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value), + }; + + Array { len: AnyLen(values.len() as u32), item: *self } + .check_value(generator, ctx, value) + .unwrap() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Instance<'ctx, M: Model<'ctx>> { + /// The model of this instance. + pub model: M, + /// The value of this instance. + /// + /// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`. + pub value: M::Value, +} diff --git a/nac3core/src/codegen/model/float.rs b/nac3core/src/codegen/model/float.rs new file mode 100644 index 0000000..88bff80 --- /dev/null +++ b/nac3core/src/codegen/model/float.rs @@ -0,0 +1,90 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicType, FloatType}, + values::FloatValue, +}; + +use crate::codegen::CodeGenerator; + +use super::*; + +pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy { + fn get_float_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Float32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Float64; + +impl<'ctx> FloatKind<'ctx> for Float32 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f32_type() + } +} + +impl<'ctx> FloatKind<'ctx> for Float64 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f64_type() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AnyFloat<'ctx>(FloatType<'ctx>); + +impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> { + fn get_float_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> FloatType<'ctx> { + self.0 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Float(pub N); + +impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float { + type Value = FloatValue<'ctx>; + type Type = FloatType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_float_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = FloatType::try_from(ty) else { + return Err(ModelError(format!("Expecting FloatType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_float_type(generator, ctx); + + // TODO: Inkwell does not have get_bit_width for FloatType? + if ty != exp_ty { + return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}"))); + } + + Ok(()) + } +} diff --git a/nac3core/src/codegen/model/function.rs b/nac3core/src/codegen/model/function.rs new file mode 100644 index 0000000..7ff2d74 --- /dev/null +++ b/nac3core/src/codegen/model/function.rs @@ -0,0 +1,122 @@ +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + types::{BasicMetadataTypeEnum, BasicType, FunctionType}, + values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue}, +}; +use itertools::Itertools; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +#[derive(Debug, Clone, Copy)] +struct Arg<'ctx> { + ty: BasicMetadataTypeEnum<'ctx>, + val: BasicMetadataValueEnum<'ctx>, +} + +/// A convenience structure to construct & call an LLVM function. +/// +/// ### Usage +/// +/// The syntax is like this: +/// ```ignore +/// let result = CallFunction::begin("my_function_name") +/// .attrs(...) +/// .arg(arg1) +/// .arg(arg2) +/// .arg(arg3) +/// .returning("my_function_result", Int32); +/// ``` +/// +/// The function `my_function_name` is called when `.returning()` (or its variants) is called, returning +/// the result as an `Instance<'ctx, Int>`. +/// +/// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function +/// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from +/// the argument types and returning type. +pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> { + generator: &'d mut G, + ctx: &'b CodeGenContext<'ctx, 'a>, + /// Function name + name: &'c str, + /// Call arguments + args: Vec>, + /// LLVM function Attributes + attrs: Vec<&'static str>, +} + +impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> { + pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self { + CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() } + } + + /// Push a list of LLVM function attributes to the function declaration. + #[must_use] + pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self { + self.attrs = attrs; + self + } + + /// Push a call argument to the function call. + #[allow(clippy::needless_pass_by_value)] + #[must_use] + pub fn arg>(mut self, arg: Instance<'ctx, M>) -> Self { + let arg = Arg { + ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(), + val: arg.value.as_basic_value_enum().into(), + }; + self.args.push(arg); + self + } + + /// Call the function and expect the function to return a value of type of `return_model`. + #[must_use] + pub fn returning>(self, name: &str, return_model: M) -> Instance<'ctx, M> { + let ret_ty = return_model.get_type(self.generator, self.ctx.ctx); + + let ret = self.call(|tys| ret_ty.fn_type(tys, false), name); + let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work + let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work + ret + } + + /// Like [`CallFunction::returning_`] but `return_model` is automatically inferred. + #[must_use] + pub fn returning_auto + Default>(self, name: &str) -> Instance<'ctx, M> { + self.returning(name, M::default()) + } + + /// Call the function and expect the function to return a void-type. + pub fn returning_void(self) { + let ret_ty = self.ctx.ctx.void_type(); + + let _ = self.call(|tys| ret_ty.fn_type(tys, false), ""); + } + + fn call(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx> + where + F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>, + { + // Get the LLVM function. + let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| { + // Declare the function if it doesn't exist. + let tys = self.args.iter().map(|arg| arg.ty).collect_vec(); + + let func_type = make_fn_type(&tys); + let func = self.ctx.module.add_function(self.name, func_type, None); + + for attr in &self.attrs { + func.add_attribute( + AttributeLoc::Function, + self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + let vals = self.args.iter().map(|arg| arg.val).collect_vec(); + self.ctx.builder.build_call(func, &vals, return_value_name).unwrap() + } +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs new file mode 100644 index 0000000..3a8a4fe --- /dev/null +++ b/nac3core/src/codegen/model/int.rs @@ -0,0 +1,417 @@ +use std::{cmp::Ordering, fmt}; + +use inkwell::{ + context::Context, + types::{BasicType, IntType}, + values::IntValue, + IntPredicate, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy { + fn get_int_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Bool; +#[derive(Debug, Clone, Copy, Default)] +pub struct Byte; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int64; +#[derive(Debug, Clone, Copy, Default)] +pub struct SizeT; + +impl<'ctx> IntKind<'ctx> for Bool { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.bool_type() + } +} + +impl<'ctx> IntKind<'ctx> for Byte { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i8_type() + } +} + +impl<'ctx> IntKind<'ctx> for Int32 { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i32_type() + } +} + +impl<'ctx> IntKind<'ctx> for Int64 { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i64_type() + } +} + +impl<'ctx> IntKind<'ctx> for SizeT { + fn get_int_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + generator.get_size_type(ctx) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AnyInt<'ctx>(pub IntType<'ctx>); + +impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> { + fn get_int_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> IntType<'ctx> { + self.0 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Int(pub N); + +impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int { + type Value = IntValue<'ctx>; + type Type = IntType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_int_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = IntType::try_from(ty) else { + return Err(ModelError(format!("Expecting IntType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_int_type(generator, ctx); + if ty.get_bit_width() != exp_ty.get_bit_width() { + return Err(ModelError(format!( + "Expecting IntType to have {} bit(s), but got {} bit(s)", + exp_ty.get_bit_width(), + ty.get_bit_width() + ))); + } + + Ok(()) + } +} + +impl<'ctx, N: IntKind<'ctx>> Int { + pub fn const_int( + &self, + generator: &mut G, + ctx: &'ctx Context, + value: u64, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_int(value, false); + self.believe_value(value) + } + + pub fn const_0( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_zero(); + self.believe_value(value) + } + + pub fn const_1( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 1) + } + + pub fn const_all_ones( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_all_ones(); + self.believe_value(value) + } + + pub fn s_extend_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_s_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn s_extend( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + pub fn z_extend_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_z_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn z_extend( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + pub fn truncate_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + >= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_truncate_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + > self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + /// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths. + pub fn s_extend_or_truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + let their_width = value.get_type().get_bit_width(); + let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + match their_width.cmp(&our_width) { + Ordering::Less => self.s_extend(generator, ctx, value), + Ordering::Equal => self.believe_value(value), + Ordering::Greater => self.truncate(generator, ctx, value), + } + } + + /// `zext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths. + pub fn z_extend_or_truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + let their_width = value.get_type().get_bit_width(); + let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + match their_width.cmp(&our_width) { + Ordering::Less => self.z_extend(generator, ctx, value), + Ordering::Equal => self.believe_value(value), + Ordering::Greater => self.truncate(generator, ctx, value), + } + } +} + +impl Int { + #[must_use] + pub fn const_false<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 0) + } + + #[must_use] + pub fn const_true<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 1) + } +} + +impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int> { + pub fn s_extend_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value) + } + + pub fn s_extend, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend(generator, ctx, self.value) + } + + pub fn z_extend_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend_or_bit_cast(generator, ctx, self.value) + } + + pub fn z_extend, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend(generator, ctx, self.value) + } + + pub fn truncate_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).truncate_or_bit_cast(generator, ctx, self.value) + } + + pub fn truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).truncate(generator, ctx, self.value) + } + + pub fn s_extend_or_truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend_or_truncate(generator, ctx, self.value) + } + + pub fn z_extend_or_truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend_or_truncate(generator, ctx, self.value) + } + + #[must_use] + pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + pub fn compare( + &self, + ctx: &CodeGenContext<'ctx, '_>, + op: IntPredicate, + other: Self, + ) -> Instance<'ctx, Int> { + let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap(); + Int(Bool).believe_value(value) + } +} diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs new file mode 100644 index 0000000..22bb333 --- /dev/null +++ b/nac3core/src/codegen/model/mod.rs @@ -0,0 +1,16 @@ +mod any; +mod array; +mod core; +mod float; +pub mod function; +mod int; +mod ptr; +mod structure; + +pub use any::*; +pub use array::*; +pub use core::*; +pub use float::*; +pub use int::*; +pub use ptr::*; +pub use structure::*; diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs new file mode 100644 index 0000000..adf241e --- /dev/null +++ b/nac3core/src/codegen/model/ptr.rs @@ -0,0 +1,191 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; + +use crate::codegen::{llvm_intrinsics::call_memcpy_generic, CodeGenContext, CodeGenerator}; + +use super::*; + +/// A model for [`PointerType`]. +/// +/// `Item` is the element type this pointer is pointing to, and should be of a [`Model`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct Ptr(pub Item); + +impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr { + type Value = PointerValue<'ctx>; + type Type = PointerType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_type(generator, ctx).ptr_type(AddressSpace::default()) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = PointerType::try_from(ty) else { + return Err(ModelError(format!("Expecting PointerType, but got {ty:?}"))); + }; + + let elem_ty = ty.get_element_type(); + let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else { + return Err(ModelError(format!( + "Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}" + ))); + }; + + // TODO: inkwell `get_element_type()` will be deprecated. + // Remove the check for `get_element_type()` when the time comes. + self.0 + .check_type(generator, ctx, elem_ty) + .map_err(|err| err.under_context("a PointerType"))?; + + Ok(()) + } +} + +impl<'ctx, Element: Model<'ctx>> Ptr { + /// Return a ***constant*** nullptr. + pub fn nullptr( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Ptr> { + let ptr = self.get_type(generator, ctx).const_null(); + self.believe_value(ptr) + } + + /// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`] + pub fn pointer_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let t = self.get_type(generator, ctx.ctx); + let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap(); + self.believe_value(ptr) + } +} + +impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr> { + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`]. + #[must_use] + pub fn offset( + &self, + ctx: &CodeGenContext<'ctx, '_>, + offset: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() }; + self.model.believe_value(p) + } + + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset. + #[must_use] + pub fn offset_const( + &self, + ctx: &CodeGenContext<'ctx, '_>, + offset: u64, + ) -> Instance<'ctx, Ptr> { + let offset = ctx.ctx.i32_type().const_int(offset, false); + self.offset(ctx, offset) + } + + pub fn set_index( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + value: Instance<'ctx, Item>, + ) { + self.offset(ctx, index).store(ctx, value); + } + + pub fn set_index_const( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: u64, + value: Instance<'ctx, Item>, + ) { + self.offset_const(ctx, index).store(ctx, value); + } + + pub fn get_index( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + ) -> Instance<'ctx, Item> { + self.offset(ctx, index).load(generator, ctx) + } + + pub fn get_index_const( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + index: u64, + ) -> Instance<'ctx, Item> { + self.offset_const(ctx, index).load(generator, ctx) + } + + /// Load the value with [`inkwell::builder::Builder::build_load`]. + pub fn load( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Item> { + let value = ctx.builder.build_load(self.value, "").unwrap(); + self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error. + } + + /// Store a value with [`inkwell::builder::Builder::build_store`]. + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Item>) { + ctx.builder.build_store(self.value, value.value).unwrap(); + } + + /// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`]. + pub fn pointer_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + new_item: NewItem, + ) -> Instance<'ctx, Ptr> { + Ptr(new_item).pointer_cast(generator, ctx, self.value) + } + + /// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`]. + pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int> { + let value = ctx.builder.build_is_null(self.value, "").unwrap(); + Int(Bool).believe_value(value) + } + + /// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`]. + pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int> { + let value = ctx.builder.build_is_not_null(self.value, "").unwrap(); + Int(Bool).believe_value(value) + } + + /// `memcpy` from another pointer. + pub fn copy_from( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + source: Self, + num_items: IntValue<'ctx>, + ) { + // Force extend `num_items` and `itemsize` to `i64` so their types would match. + let itemsize = self.model.sizeof(generator, ctx.ctx); + let itemsize = Int(Int64).z_extend_or_truncate(generator, ctx, itemsize); + let num_items = Int(Int64).z_extend_or_truncate(generator, ctx, num_items); + let totalsize = itemsize.mul(ctx, num_items); + + let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false + call_memcpy_generic(ctx, self.value, source.value, totalsize.value, is_volatile); + } +} diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs new file mode 100644 index 0000000..a989904 --- /dev/null +++ b/nac3core/src/codegen/model/structure.rs @@ -0,0 +1,359 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, StructType}, + values::{BasicValueEnum, StructValue}, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types. +pub trait FieldTraversal<'ctx> { + /// Output type of [`FieldTraversal::add`]. + type Out; + + /// Traverse through the type of a declared field and do something with it. + /// + /// * `name` - The cosmetic name of the LLVM field. Used for debugging. + /// * `model` - The [`Model`] representing the LLVM type of this field. + fn add>(&mut self, name: &'static str, model: M) -> Self::Out; + + /// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait. + fn add_auto + Default>(&mut self, name: &'static str) -> Self::Out { + self.add(name, M::default()) + } +} + +/// Descriptor of an LLVM struct field. +#[derive(Debug, Clone, Copy)] +pub struct GepField { + /// The GEP index of this field. This is the index to use with `build_gep`. + pub gep_index: u64, + /// The cosmetic name of this field. + pub name: &'static str, + /// The [`Model`] of this field's type. + pub model: M, +} + +/// A traversal to calculate the GEP index of fields. +pub struct GepFieldTraversal { + /// The current GEP index. + gep_index_counter: u64, +} + +impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal { + type Out = GepField; + + fn add>(&mut self, name: &'static str, model: M) -> Self::Out { + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + Self::Out { gep_index, name, model } + } +} + +/// A traversal to collect the field types of a struct. +/// +/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`]. +struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { + generator: &'a G, + ctx: &'ctx Context, + /// The collected field types so far in exact order. + field_types: Vec>, +} + +impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> { + type Out = (); // Checking types return nothing. + + fn add>(&mut self, _name: &'static str, model: M) -> Self::Out { + let t = model.get_type(self.generator, self.ctx).as_basic_type_enum(); + self.field_types.push(t); + } +} + +/// A traversal to check the types of fields. +struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { + generator: &'a mut G, + ctx: &'ctx Context, + /// The current GEP index, so we can tell the index of the field we are checking + /// and report the GEP index. + gep_index_counter: u32, + /// The [`StructType`] to check. + scrutinee: StructType<'ctx>, + /// The list of collected errors so far. + errors: Vec, +} + +impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> + for CheckTypeFieldTraversal<'ctx, 'a, G> +{ + type Out = (); // Checking types return nothing. + + fn add>(&mut self, name: &'static str, model: M) -> Self::Out { + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + + if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) { + if let Err(err) = model.check_type(self.generator, self.ctx, t) { + self.errors + .push(err.under_context(format!("field #{gep_index} '{name}'").as_str())); + } + } // Otherwise, it will be caught by Struct's `check_type`. + } +} + +/// A trait for Rust structs identifying LLVM structures. +/// +/// ### Example +/// +/// Suppose you want to define this structure: +/// ```c +/// template +/// struct ContiguousNDArray { +/// size_t ndims; +/// size_t* shape; +/// T* data; +/// } +/// ``` +/// +/// This is how it should be done: +/// ```ignore +/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { +/// pub ndims: F::Out>, +/// pub shape: F::Out>>, +/// pub data: F::Out>, +/// } +/// +/// /// An ndarray without strides and non-opaque `data` field in NAC3. +/// #[derive(Debug, Clone, Copy)] +/// pub struct ContiguousNDArray { +/// /// [`Model`] of the items. +/// pub item: M, +/// } +/// +/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray { +/// type Fields> = ContiguousNDArrayFields<'ctx, F, Item>; +/// +/// fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { +/// // The order of `traversal.add*` is important +/// Self::Fields { +/// ndims: traversal.add_auto("ndims"), +/// shape: traversal.add_auto("shape"), +/// data: traversal.add("data", Ptr(self.item)), +/// } +/// } +/// } +/// ``` +/// +/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be +/// traversed to do useful work such as: +/// +/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields. +/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax. +/// +/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray` +/// with dtype `float64` in LLVM, this is how you do it: +/// ```ignore +/// type F64NDArray = Struct>>; // Type alias for leaner documentation +/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) }); +/// let ndarray: Instance<'ctx, Ptr> = model.alloca(generator, ctx); +/// ``` +/// +/// ...and here is how you may manipulate/access `ndarray`: +/// +/// (NOTE: some arguments have been omitted) +/// +/// ```ignore +/// // Get `&ndarray->data` +/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr>> +/// +/// // Get `ndarray->ndims` +/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int> +/// +/// // Get `&ndarray->ndims` +/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr>> +/// +/// // Get `ndarray->shape[0]` +/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int> +/// +/// // Get `&ndarray->shape[2]` +/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr>> +/// +/// // Do `ndarray->ndims = 3;` +/// let num_3 = Int(SizeT).const_int(3); +/// ndarray.set(|f| f.ndims, num_3); +/// ``` +pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy { + /// The associated fields of this struct. + type Fields>; + + /// Traverse through all fields of this [`StructKind`]. + /// + /// Only used internally in this module for implementing other components. + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields; + + /// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field. + /// + /// Only used internally in this module for implementing other components. + fn fields(&self) -> Self::Fields { + self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 }) + } + + /// Get the LLVM [`StructType`] of this [`StructKind`]. + fn get_struct_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> StructType<'ctx> { + let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() }; + self.traverse_fields(&mut traversal); + + ctx.struct_type(&traversal.field_types, false) + } +} + +/// A model for LLVM struct. +/// +/// `S` should be of a [`StructKind`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct Struct(pub S); + +impl<'ctx, S: StructKind<'ctx>> Struct { + /// Create a constant struct value from its fields. + /// + /// This function also validates `fields` and panic when there is something wrong. + pub fn const_struct( + &self, + generator: &mut G, + ctx: &'ctx Context, + fields: &[BasicValueEnum<'ctx>], + ) -> Instance<'ctx, Self> { + // NOTE: There *could* have been a functor `F = Instance<'ctx, M>` for `S::Fields` + // to create a more user-friendly interface, but Rust's type system is not sophisticated enough + // and if you try doing that Rust would force you put lifetimes everywhere. + let val = ctx.const_struct(fields, false); + self.check_value(generator, ctx, val).unwrap() + } +} + +impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct { + type Value = StructValue<'ctx>; + type Type = StructType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_struct_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = StructType::try_from(ty) else { + return Err(ModelError(format!("Expecting StructType, but got {ty:?}"))); + }; + + // Check each field individually. + let mut traversal = CheckTypeFieldTraversal { + generator, + ctx, + gep_index_counter: 0, + errors: Vec::new(), + scrutinee: ty, + }; + self.0.traverse_fields(&mut traversal); + + // Check the number of fields. + let exp_num_fields = traversal.gep_index_counter; + let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap(); + if exp_num_fields != got_num_fields { + return Err(ModelError(format!( + "Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}" + ))); + } + + if !traversal.errors.is_empty() { + // Currently, only the first error is reported. + return Err(traversal.errors[0].clone()); + } + + Ok(()) + } +} + +impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct> { + /// Get a field with [`StructValue::get_field_at_index`]. + pub fn get_field( + &self, + generator: &mut G, + ctx: &'ctx Context, + get_field: GetField, + ) -> Instance<'ctx, M> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + let field = get_field(self.model.0.fields()); + let val = self.value.get_field_at_index(field.gep_index as u32).unwrap(); + field.model.check_value(generator, ctx, val).unwrap() + } +} + +impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { + /// Get a pointer to a field with [`Builder::build_in_bounds_gep`]. + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + ) -> Instance<'ctx, Ptr> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + let field = get_field(self.model.0 .0.fields()); + let llvm_i32 = ctx.ctx.i32_type(); + + let ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + self.value, + &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + field.name, + ) + .unwrap() + }; + + Ptr(field.model).believe_value(ptr) + } + + /// Convenience function equivalent to `.gep(...).load(...)`. + pub fn get( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + ) -> Instance<'ctx, M> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + self.gep(ctx, get_field).load(generator, ctx) + } + + /// Convenience function equivalent to `.gep(...).store(...)`. + pub fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + value: Instance<'ctx, M>, + ) where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + self.gep(ctx, get_field).store(ctx, value); + } +}