diff --git a/nac3core/src/codegen/model/any.rs b/nac3core/src/codegen/model/any.rs index 9df863e..b74ba32 100644 --- a/nac3core/src/codegen/model/any.rs +++ b/nac3core/src/codegen/model/any.rs @@ -1,11 +1,9 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + types::{BasicType, BasicTypeEnum, IntType}, + values::IntValue, }; -use crate::codegen::CodeGenerator; - use super::*; /// A [`Model`] of any [`BasicTypeEnum`]. @@ -14,25 +12,17 @@ use super::*; #[derive(Debug, Clone, Copy)] pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>); -impl<'ctx> Model<'ctx> for Any<'ctx> { - type Value = BasicValueEnum<'ctx>; - type Type = BasicTypeEnum<'ctx>; - - fn get_type( - &self, - _generator: &G, - _ctx: &'ctx Context, - ) -> Self::Type { - self.0 +impl<'ctx> ModelBase<'ctx> for Any<'ctx> { + fn get_type_impl(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - _generator: &mut G, + _size_t: IntType<'ctx>, _ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); if ty == self.0 { Ok(()) } else { @@ -40,3 +30,8 @@ impl<'ctx> Model<'ctx> for Any<'ctx> { } } } + +impl<'ctx> Model<'ctx> for Any<'ctx> { + type Type = IntType<'ctx>; + type Value = IntValue<'ctx>; +} diff --git a/nac3core/src/codegen/model/array.rs b/nac3core/src/codegen/model/array.rs index be8dc0b..70cd129 100644 --- a/nac3core/src/codegen/model/array.rs +++ b/nac3core/src/codegen/model/array.rs @@ -2,7 +2,7 @@ use std::fmt; use inkwell::{ context::Context, - types::{ArrayType, BasicType, BasicTypeEnum}, + types::{ArrayType, BasicType, BasicTypeEnum, IntType}, values::{ArrayValue, IntValue}, }; @@ -46,21 +46,18 @@ pub struct Array { pub item: Item, } -impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array { - type Value = ArrayValue<'ctx>; - type Type = ArrayType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { - self.item.get_type(generator, ctx).array_type(self.len.get_length()) +impl<'ctx, Len: LenKind, Item: ModelBase<'ctx>> ModelBase<'ctx> for Array { + fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + let item = self.item.get_type_impl(size_t, ctx); + item.array_type(self.len.get_length()).into() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - generator: &mut G, + size_t: IntType<'ctx>, ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); let BasicTypeEnum::ArrayType(ty) = ty else { return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}"))); }; @@ -74,13 +71,18 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array { } self.item - .check_type(generator, ctx, ty.get_element_type()) + .check_type_impl(size_t, ctx, ty.get_element_type()) .map_err(|err| err.under_context("an ArrayType"))?; Ok(()) } } +impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array { + type Type = ArrayType<'ctx>; + type Value = ArrayValue<'ctx>; +} + impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr>> { /// Get the pointer to the `i`-th (0-based) array element. pub fn gep( diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 25faeea..55ad463 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -19,6 +19,23 @@ impl ModelError { } } +// NOTE: A watered down version of `Model` trait. Made to be object safe. +pub trait ModelBase<'ctx> { + // NOTE: Taking `size_t` here instead of `CodeGenerator` to be object safe. + // In fact, all the entire model abstraction need from the `CodeGenerator` is its `get_size_type()`. + + // NOTE: Model's get_type but object-safe and returns BasicTypeEnum, instead of a known BasicType variant. + fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; + + // NOTE: Model's check_type but object-safe. + fn check_type_impl( + &self, + size_t: IntType<'ctx>, + ctx: &'ctx Context, + scrutinee: BasicTypeEnum<'ctx>, + ) -> Result<(), ModelError>; +} + /// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`]. /// /// For instance, @@ -59,16 +76,24 @@ impl ModelError { /// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time: /// let my_value = Int(Int32).believe_value(my_value); /// ``` -pub trait Model<'ctx>: fmt::Debug + Clone + Copy { +pub trait Model<'ctx>: fmt::Debug + Clone + Copy + ModelBase<'ctx> { /// The [`BasicType`] *variant* this model is identifying. - type Type: BasicType<'ctx>; + type Type: BasicType<'ctx> + TryFrom>; /// The [`BasicValue`] type of the [`BasicType`] of this model. type Value: BasicValue<'ctx> + TryFrom>; /// Return the [`BasicType`] of this model. #[must_use] - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + let size_t = generator.get_size_type(ctx); + + let ty = self.get_type_impl(size_t, ctx); + match Self::Type::try_from(ty) { + Ok(ty) => ty, + _ => panic!("Model::Type is inconsistent with what is returned from ModelBase::get_type_impl()! Got {ty:?}."), + } + } /// Get the number of bytes of the [`BasicType`] of this model. fn sizeof( @@ -85,7 +110,10 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy { generator: &mut G, ctx: &'ctx Context, ty: T, - ) -> Result<(), ModelError>; + ) -> Result<(), ModelError> { + let size_t = generator.get_size_type(ctx); + self.check_type_impl(size_t, ctx, ty.as_basic_type_enum()) + } /// Create an instance from a value. /// diff --git a/nac3core/src/codegen/model/float.rs b/nac3core/src/codegen/model/float.rs index 88bff80..ef5111a 100644 --- a/nac3core/src/codegen/model/float.rs +++ b/nac3core/src/codegen/model/float.rs @@ -2,20 +2,14 @@ use std::fmt; use inkwell::{ context::Context, - types::{BasicType, FloatType}, + types::{BasicTypeEnum, FloatType, IntType}, values::FloatValue, }; -use crate::codegen::CodeGenerator; - use super::*; pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy { - fn get_float_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> FloatType<'ctx>; + fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx>; } #[derive(Debug, Clone, Copy, Default)] @@ -24,21 +18,13 @@ pub struct Float32; pub struct Float64; impl<'ctx> FloatKind<'ctx> for Float32 { - fn get_float_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> FloatType<'ctx> { + fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> { ctx.f32_type() } } impl<'ctx> FloatKind<'ctx> for Float64 { - fn get_float_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> FloatType<'ctx> { + fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> { ctx.f64_type() } } @@ -47,11 +33,7 @@ impl<'ctx> FloatKind<'ctx> for Float64 { pub struct AnyFloat<'ctx>(FloatType<'ctx>); impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> { - fn get_float_type( - &self, - _generator: &G, - _ctx: &'ctx Context, - ) -> FloatType<'ctx> { + fn get_float_type(&self, _ctx: &'ctx Context) -> FloatType<'ctx> { self.0 } } @@ -59,32 +41,31 @@ impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> { #[derive(Debug, Clone, Copy, Default)] pub struct Float(pub N); -impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float { - type Value = FloatValue<'ctx>; - type Type = FloatType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { - self.0.get_float_type(generator, ctx) +impl<'ctx, N: FloatKind<'ctx>> ModelBase<'ctx> for Float { + fn get_type_impl(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_float_type(ctx).into() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - generator: &mut G, + _size_t: IntType<'ctx>, ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); let Ok(ty) = FloatType::try_from(ty) else { return Err(ModelError(format!("Expecting FloatType, but got {ty:?}"))); }; - let exp_ty = self.0.get_float_type(generator, ctx); - - // TODO: Inkwell does not have get_bit_width for FloatType? - if ty != exp_ty { - return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}"))); + let expected_ty = self.0.get_float_type(ctx); + if ty != expected_ty { + return Err(ModelError(format!("Expecting {expected_ty:?}, but got {ty:?}"))); } Ok(()) } } + +impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float { + type Value = FloatValue<'ctx>; + type Type = FloatType<'ctx>; +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs index 3a8a4fe..13667cf 100644 --- a/nac3core/src/codegen/model/int.rs +++ b/nac3core/src/codegen/model/int.rs @@ -2,7 +2,7 @@ use std::{cmp::Ordering, fmt}; use inkwell::{ context::Context, - types::{BasicType, IntType}, + types::{BasicTypeEnum, IntType}, values::IntValue, IntPredicate, }; @@ -12,11 +12,7 @@ use crate::codegen::{CodeGenContext, CodeGenerator}; use super::*; pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy { - fn get_int_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx>; + fn get_int_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx>; } #[derive(Debug, Clone, Copy, Default)] @@ -31,52 +27,32 @@ pub struct Int64; pub struct SizeT; impl<'ctx> IntKind<'ctx> for Bool { - fn get_int_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx> { + fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { ctx.bool_type() } } impl<'ctx> IntKind<'ctx> for Byte { - fn get_int_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx> { + fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { ctx.i8_type() } } impl<'ctx> IntKind<'ctx> for Int32 { - fn get_int_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx> { + fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { ctx.i32_type() } } impl<'ctx> IntKind<'ctx> for Int64 { - fn get_int_type( - &self, - _generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx> { + fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> { ctx.i64_type() } } impl<'ctx> IntKind<'ctx> for SizeT { - fn get_int_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> IntType<'ctx> { - generator.get_size_type(ctx) + fn get_int_type(&self, size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> { + size_t } } @@ -84,11 +60,7 @@ impl<'ctx> IntKind<'ctx> for SizeT { pub struct AnyInt<'ctx>(pub IntType<'ctx>); impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> { - fn get_int_type( - &self, - _generator: &G, - _ctx: &'ctx Context, - ) -> IntType<'ctx> { + fn get_int_type(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> { self.0 } } @@ -96,26 +68,22 @@ impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> { #[derive(Debug, Clone, Copy, Default)] pub struct Int(pub N); -impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int { - type Value = IntValue<'ctx>; - type Type = IntType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { - self.0.get_int_type(generator, ctx) +impl<'ctx, N: IntKind<'ctx>> ModelBase<'ctx> for Int { + fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_int_type(size_t, ctx).into() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - generator: &mut G, + size_t: IntType<'ctx>, ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); let Ok(ty) = IntType::try_from(ty) else { return Err(ModelError(format!("Expecting IntType, but got {ty:?}"))); }; - let exp_ty = self.0.get_int_type(generator, ctx); + let exp_ty = self.0.get_int_type(size_t, ctx); if ty.get_bit_width() != exp_ty.get_bit_width() { return Err(ModelError(format!( "Expecting IntType to have {} bit(s), but got {} bit(s)", @@ -128,6 +96,11 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int { } } +impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int { + type Type = IntType<'ctx>; + type Value = IntValue<'ctx>; +} + impl<'ctx, N: IntKind<'ctx>> Int { pub fn const_int( &self, @@ -173,7 +146,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + <= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx .builder @@ -190,7 +163,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + < self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); @@ -205,7 +178,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + <= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx .builder @@ -222,7 +195,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + < self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); @@ -237,7 +210,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - >= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + >= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx .builder @@ -254,7 +227,7 @@ impl<'ctx, N: IntKind<'ctx>> Int { ) -> Instance<'ctx, Self> { assert!( value.get_type().get_bit_width() - > self.0.get_int_type(generator, ctx.ctx).get_bit_width() + > self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width() ); let value = ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap(); @@ -269,7 +242,8 @@ impl<'ctx, N: IntKind<'ctx>> Int { value: IntValue<'ctx>, ) -> Instance<'ctx, Self> { let their_width = value.get_type().get_bit_width(); - let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + let our_width = + self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width(); match their_width.cmp(&our_width) { Ordering::Less => self.s_extend(generator, ctx, value), Ordering::Equal => self.believe_value(value), @@ -285,7 +259,8 @@ impl<'ctx, N: IntKind<'ctx>> Int { value: IntValue<'ctx>, ) -> Instance<'ctx, Self> { let their_width = value.get_type().get_bit_width(); - let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + let our_width = + self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width(); match their_width.cmp(&our_width) { Ordering::Less => self.z_extend(generator, ctx, value), Ordering::Equal => self.believe_value(value), diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs index e2d6128..ac9493b 100644 --- a/nac3core/src/codegen/model/ptr.rs +++ b/nac3core/src/codegen/model/ptr.rs @@ -1,6 +1,6 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, PointerType}, + types::{BasicType, BasicTypeEnum, IntType, PointerType}, values::{IntValue, PointerValue}, AddressSpace, }; @@ -23,26 +23,23 @@ pub struct Ptr(pub Item); /// `.load()/.store()` is not available for [`Instance`]s of opaque pointers. pub type OpaquePtr = Ptr<()>; -// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be +// TODO: LLVM 15: `Item: ModelBase<'ctx>` don't even need to be a model anymore. It will only be // a type-hint for the `.load()/.store()` functions for the `pointee_ty`. // // See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load. -impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr { - type Value = PointerValue<'ctx>; - type Type = PointerType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { +impl<'ctx, Item: ModelBase<'ctx>> ModelBase<'ctx> for Ptr { + fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { // TODO: LLVM 15: ctx.ptr_type(AddressSpace::default()) - self.0.get_type(generator, ctx).ptr_type(AddressSpace::default()) + let item = self.0.get_type_impl(size_t, ctx); + item.ptr_type(AddressSpace::default()).into() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - generator: &mut G, + size_t: IntType<'ctx>, ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); let Ok(ty) = PointerType::try_from(ty) else { return Err(ModelError(format!("Expecting PointerType, but got {ty:?}"))); }; @@ -57,13 +54,18 @@ impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr { // TODO: inkwell `get_element_type()` will be deprecated. // Remove the check for `get_element_type()` when the time comes. self.0 - .check_type(generator, ctx, elem_ty) + .check_type_impl(size_t, ctx, elem_ty) .map_err(|err| err.under_context("a PointerType"))?; Ok(()) } } +impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr { + type Type = PointerType<'ctx>; + type Value = PointerValue<'ctx>; +} + impl<'ctx, Item: Model<'ctx>> Ptr { /// Return a ***constant*** nullptr. pub fn nullptr( @@ -71,6 +73,7 @@ impl<'ctx, Item: Model<'ctx>> Ptr { generator: &mut G, ctx: &'ctx Context, ) -> Instance<'ctx, Ptr> { + // TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`. let ptr = self.get_type(generator, ctx).const_null(); self.believe_value(ptr) } diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs index a989904..f94f460 100644 --- a/nac3core/src/codegen/model/structure.rs +++ b/nac3core/src/codegen/model/structure.rs @@ -1,290 +1,141 @@ -use std::fmt; +use std::{fmt, marker::PhantomData}; use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, StructType}, + types::{BasicType, BasicTypeEnum, IntType, StructType}, values::{BasicValueEnum, StructValue}, }; +use itertools::{izip, Itertools}; use crate::codegen::{CodeGenContext, CodeGenerator}; use super::*; -/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types. -pub trait FieldTraversal<'ctx> { - /// Output type of [`FieldTraversal::add`]. - type Out; +// pub trait StructKind2<'ctx>: fmt::Debug + Clone + Copy { +// type Fields> = ; +// } - /// Traverse through the type of a declared field and do something with it. - /// - /// * `name` - The cosmetic name of the LLVM field. Used for debugging. - /// * `model` - The [`Model`] representing the LLVM type of this field. - fn add>(&mut self, name: &'static str, model: M) -> Self::Out; +pub struct Field { + gep_index: u32, + model: M, + name: &'static str, +} - /// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait. - fn add_auto + Default>(&mut self, name: &'static str) -> Self::Out { +// NOTE: Very similar to Field, but is forall on `M`, (and also uses ModelBase to get object safety for the `Box`. +pub struct Entry<'ctx> { + model: Box + 'ctx>, + name: &'static str, +} + +pub struct FieldMapper<'ctx> { + gep_index_counter: u32, + entries: Vec>, +} + +impl<'ctx> FieldMapper<'ctx> { + fn add>(&mut self, name: &'static str, model: M) -> Field { + let entry = Entry { model: Box::new(model), name }; + self.entries.push(entry); + + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + Field { gep_index, model, name } + } + + fn add_auto + Default>(&mut self, name: &'static str) -> Field { self.add(name, M::default()) } } -/// Descriptor of an LLVM struct field. -#[derive(Debug, Clone, Copy)] -pub struct GepField { - /// The GEP index of this field. This is the index to use with `build_gep`. - pub gep_index: u64, - /// The cosmetic name of this field. - pub name: &'static str, - /// The [`Model`] of this field's type. - pub model: M, -} - -/// A traversal to calculate the GEP index of fields. -pub struct GepFieldTraversal { - /// The current GEP index. - gep_index_counter: u64, -} - -impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal { - type Out = GepField; - - fn add>(&mut self, name: &'static str, model: M) -> Self::Out { - let gep_index = self.gep_index_counter; - self.gep_index_counter += 1; - Self::Out { gep_index, name, model } - } -} - -/// A traversal to collect the field types of a struct. -/// -/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`]. -struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { - generator: &'a G, - ctx: &'ctx Context, - /// The collected field types so far in exact order. - field_types: Vec>, -} - -impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> { - type Out = (); // Checking types return nothing. - - fn add>(&mut self, _name: &'static str, model: M) -> Self::Out { - let t = model.get_type(self.generator, self.ctx).as_basic_type_enum(); - self.field_types.push(t); - } -} - -/// A traversal to check the types of fields. -struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { - generator: &'a mut G, - ctx: &'ctx Context, - /// The current GEP index, so we can tell the index of the field we are checking - /// and report the GEP index. - gep_index_counter: u32, - /// The [`StructType`] to check. - scrutinee: StructType<'ctx>, - /// The list of collected errors so far. - errors: Vec, -} - -impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> - for CheckTypeFieldTraversal<'ctx, 'a, G> -{ - type Out = (); // Checking types return nothing. - - fn add>(&mut self, name: &'static str, model: M) -> Self::Out { - let gep_index = self.gep_index_counter; - self.gep_index_counter += 1; - - if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) { - if let Err(err) = model.check_type(self.generator, self.ctx, t) { - self.errors - .push(err.under_context(format!("field #{gep_index} '{name}'").as_str())); - } - } // Otherwise, it will be caught by Struct's `check_type`. - } -} - -/// A trait for Rust structs identifying LLVM structures. -/// -/// ### Example -/// -/// Suppose you want to define this structure: -/// ```c -/// template -/// struct ContiguousNDArray { -/// size_t ndims; -/// size_t* shape; -/// T* data; -/// } -/// ``` -/// -/// This is how it should be done: -/// ```ignore -/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { -/// pub ndims: F::Out>, -/// pub shape: F::Out>>, -/// pub data: F::Out>, -/// } -/// -/// /// An ndarray without strides and non-opaque `data` field in NAC3. -/// #[derive(Debug, Clone, Copy)] -/// pub struct ContiguousNDArray { -/// /// [`Model`] of the items. -/// pub item: M, -/// } -/// -/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray { -/// type Fields> = ContiguousNDArrayFields<'ctx, F, Item>; -/// -/// fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { -/// // The order of `traversal.add*` is important -/// Self::Fields { -/// ndims: traversal.add_auto("ndims"), -/// shape: traversal.add_auto("shape"), -/// data: traversal.add("data", Ptr(self.item)), -/// } -/// } -/// } -/// ``` -/// -/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be -/// traversed to do useful work such as: -/// -/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields. -/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax. -/// -/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray` -/// with dtype `float64` in LLVM, this is how you do it: -/// ```ignore -/// type F64NDArray = Struct>>; // Type alias for leaner documentation -/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) }); -/// let ndarray: Instance<'ctx, Ptr> = model.alloca(generator, ctx); -/// ``` -/// -/// ...and here is how you may manipulate/access `ndarray`: -/// -/// (NOTE: some arguments have been omitted) -/// -/// ```ignore -/// // Get `&ndarray->data` -/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr>> -/// -/// // Get `ndarray->ndims` -/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int> -/// -/// // Get `&ndarray->ndims` -/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr>> -/// -/// // Get `ndarray->shape[0]` -/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int> -/// -/// // Get `&ndarray->shape[2]` -/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr>> -/// -/// // Do `ndarray->ndims = 3;` -/// let num_3 = Int(SizeT).const_int(3); -/// ndarray.set(|f| f.ndims, num_3); -/// ``` pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy { - /// The associated fields of this struct. - type Fields>; + type Fields; - /// Traverse through all fields of this [`StructKind`]. - /// - /// Only used internally in this module for implementing other components. - fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields; + fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields; - /// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field. - /// - /// Only used internally in this module for implementing other components. - fn fields(&self) -> Self::Fields { - self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 }) + // Produce `Vec` and `Self::Fields` simultaneously. + // The former is for doing field-wise type checks. + // The latter is for enabling the `.gep(|f| f.data)` syntax. + fn entries_and_fields(&self) -> (Vec>, Self::Fields) { + let mut mapper = FieldMapper { gep_index_counter: 0, entries: Vec::new() }; + let fields = self.iter_fields(&mut mapper); + (mapper.entries, fields) + } + + fn entries(&self) -> Vec> { + self.entries_and_fields().0 + } + + fn fields(&self) -> Self::Fields { + self.entries_and_fields().1 } /// Get the LLVM [`StructType`] of this [`StructKind`]. - fn get_struct_type( - &self, - generator: &G, - ctx: &'ctx Context, - ) -> StructType<'ctx> { - let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() }; - self.traverse_fields(&mut traversal); - - ctx.struct_type(&traversal.field_types, false) + fn get_struct_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> StructType<'ctx> { + let entries = self.entries(); + let entries = entries.into_iter().map(|t| t.model.get_type_impl(size_t, ctx)).collect_vec(); + ctx.struct_type(&entries, false) } } -/// A model for LLVM struct. -/// -/// `S` should be of a [`StructKind`]. #[derive(Debug, Clone, Copy, Default)] pub struct Struct(pub S); impl<'ctx, S: StructKind<'ctx>> Struct { - /// Create a constant struct value from its fields. - /// - /// This function also validates `fields` and panic when there is something wrong. pub fn const_struct( &self, generator: &mut G, ctx: &'ctx Context, fields: &[BasicValueEnum<'ctx>], ) -> Instance<'ctx, Self> { - // NOTE: There *could* have been a functor `F = Instance<'ctx, M>` for `S::Fields` - // to create a more user-friendly interface, but Rust's type system is not sophisticated enough - // and if you try doing that Rust would force you put lifetimes everywhere. let val = ctx.const_struct(fields, false); self.check_value(generator, ctx, val).unwrap() } } -impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct { - type Value = StructValue<'ctx>; - type Type = StructType<'ctx>; - - fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { - self.0.get_struct_type(generator, ctx) +impl<'ctx, S: StructKind<'ctx>> ModelBase<'ctx> for Struct { + fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_struct_type(size_t, ctx).as_basic_type_enum() } - fn check_type, G: CodeGenerator + ?Sized>( + fn check_type_impl( &self, - generator: &mut G, + size_t: IntType<'ctx>, ctx: &'ctx Context, - ty: T, + ty: BasicTypeEnum<'ctx>, ) -> Result<(), ModelError> { - let ty = ty.as_basic_type_enum(); let Ok(ty) = StructType::try_from(ty) else { return Err(ModelError(format!("Expecting StructType, but got {ty:?}"))); }; - // Check each field individually. - let mut traversal = CheckTypeFieldTraversal { - generator, - ctx, - gep_index_counter: 0, - errors: Vec::new(), - scrutinee: ty, - }; - self.0.traverse_fields(&mut traversal); + let entries = self.0.entries(); + let field_types = ty.get_field_types(); // Check the number of fields. - let exp_num_fields = traversal.gep_index_counter; - let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap(); - if exp_num_fields != got_num_fields { + if entries.len() != field_types.len() { return Err(ModelError(format!( - "Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}" + "Expecting StructType with {} field(s), but got {}", + entries.len(), + field_types.len() ))); } - if !traversal.errors.is_empty() { - // Currently, only the first error is reported. - return Err(traversal.errors[0].clone()); + // Check each field. + for (i, (entry, field_type)) in izip!(entries, field_types).enumerate() { + entry.model.check_type_impl(size_t, ctx, field_type).map_err(|err| { + let context = &format!("in field #{i} '{}'", entry.name); + err.under_context(context) + })?; } Ok(()) } } +impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct { + type Type = StructType<'ctx>; + type Value = StructValue<'ctx>; +} + impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct> { /// Get a field with [`StructValue::get_field_at_index`]. pub fn get_field( @@ -295,10 +146,10 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct> { ) -> Instance<'ctx, M> where M: Model<'ctx>, - GetField: FnOnce(S::Fields) -> GepField, + GetField: FnOnce(S::Fields) -> Field, { let field = get_field(self.model.0.fields()); - let val = self.value.get_field_at_index(field.gep_index as u32).unwrap(); + let val = self.value.get_field_at_index(field.gep_index).unwrap(); field.model.check_value(generator, ctx, val).unwrap() } } @@ -312,7 +163,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { ) -> Instance<'ctx, Ptr> where M: Model<'ctx>, - GetField: FnOnce(S::Fields) -> GepField, + GetField: FnOnce(S::Fields) -> Field, { let field = get_field(self.model.0 .0.fields()); let llvm_i32 = ctx.ctx.i32_type(); @@ -321,7 +172,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { ctx.builder .build_in_bounds_gep( self.value, - &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + &[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)], field.name, ) .unwrap() @@ -339,7 +190,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { ) -> Instance<'ctx, M> where M: Model<'ctx>, - GetField: FnOnce(S::Fields) -> GepField, + GetField: FnOnce(S::Fields) -> Field, { self.gep(ctx, get_field).load(generator, ctx) } @@ -352,8 +203,65 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { value: Instance<'ctx, M>, ) where M: Model<'ctx>, - GetField: FnOnce(S::Fields) -> GepField, + GetField: FnOnce(S::Fields) -> Field, { self.gep(ctx, get_field).store(ctx, value); } } + +/////////////////////// Example; Delete later + +// Example: NDArray. +// +// Compared to List, it has no generic models. +pub struct NDArrayFields { + data: Field>>, + itemsize: Field>, + ndims: Field>, + shape: Field>>, + strides: Field>>, +} + +#[derive(Debug, Clone, Copy, Default)] +struct NDArray; + +impl<'ctx> StructKind<'ctx> for NDArray { + type Fields = NDArrayFields; + + fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields { + NDArrayFields { + data: mapper.add_auto("data"), + itemsize: mapper.add_auto("itemsize"), + ndims: mapper.add_auto("ndims"), + shape: mapper.add_auto("shape"), + strides: mapper.add_auto("strides"), + } + } +} + +// Example: List. +// +// Compared to NDArray, it has generic models. +pub struct ListFields<'ctx, Item: Model<'ctx>> { + items: Field>, + len: Field>, + _phantom: PhantomData<&'ctx ()>, +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct List<'ctx, Item: Model<'ctx>> { + item: Item, + _phantom: PhantomData<&'ctx ()>, +} + +impl<'ctx, Item: Model<'ctx> + 'ctx> StructKind<'ctx> for List<'ctx, Item> { + type Fields = ListFields<'ctx, Item>; + + fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields { + ListFields { + items: mapper.add("items", Ptr(self.item)), + len: mapper.add_auto("len"), + _phantom: PhantomData, + } + } +}