forked from M-Labs/nac3
1
0
Fork 0
This commit is contained in:
lyken 2024-08-28 11:47:57 +08:00
parent dfbbe66154
commit 4e714cb53b
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
7 changed files with 265 additions and 373 deletions

View File

@ -1,11 +1,9 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
types::{BasicType, BasicTypeEnum, IntType},
values::IntValue,
};
use crate::codegen::CodeGenerator;
use super::*;
/// A [`Model`] of any [`BasicTypeEnum`].
@ -14,25 +12,17 @@ use super::*;
#[derive(Debug, Clone, Copy)]
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
impl<'ctx> Model<'ctx> for Any<'ctx> {
type Value = BasicValueEnum<'ctx>;
type Type = BasicTypeEnum<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> Self::Type {
self.0
impl<'ctx> ModelBase<'ctx> for Any<'ctx> {
fn get_type_impl(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.as_basic_type_enum()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
_generator: &mut G,
_size_t: IntType<'ctx>,
_ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
if ty == self.0 {
Ok(())
} else {
@ -40,3 +30,8 @@ impl<'ctx> Model<'ctx> for Any<'ctx> {
}
}
}
impl<'ctx> Model<'ctx> for Any<'ctx> {
type Type = IntType<'ctx>;
type Value = IntValue<'ctx>;
}

View File

@ -2,7 +2,7 @@ use std::fmt;
use inkwell::{
context::Context,
types::{ArrayType, BasicType, BasicTypeEnum},
types::{ArrayType, BasicType, BasicTypeEnum, IntType},
values::{ArrayValue, IntValue},
};
@ -46,21 +46,18 @@ pub struct Array<Len, Item> {
pub item: Item,
}
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.item.get_type(generator, ctx).array_type(self.len.get_length())
impl<'ctx, Len: LenKind, Item: ModelBase<'ctx>> ModelBase<'ctx> for Array<Len, Item> {
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
let item = self.item.get_type_impl(size_t, ctx);
item.array_type(self.len.get_length()).into()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
generator: &mut G,
size_t: IntType<'ctx>,
ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let BasicTypeEnum::ArrayType(ty) = ty else {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
};
@ -74,13 +71,18 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
}
self.item
.check_type(generator, ctx, ty.get_element_type())
.check_type_impl(size_t, ctx, ty.get_element_type())
.map_err(|err| err.under_context("an ArrayType"))?;
Ok(())
}
}
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
type Type = ArrayType<'ctx>;
type Value = ArrayValue<'ctx>;
}
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn gep(

View File

@ -19,6 +19,23 @@ impl ModelError {
}
}
// NOTE: A watered down version of `Model` trait. Made to be object safe.
pub trait ModelBase<'ctx> {
// NOTE: Taking `size_t` here instead of `CodeGenerator` to be object safe.
// In fact, all the entire model abstraction need from the `CodeGenerator` is its `get_size_type()`.
// NOTE: Model's get_type but object-safe and returns BasicTypeEnum, instead of a known BasicType variant.
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
// NOTE: Model's check_type but object-safe.
fn check_type_impl(
&self,
size_t: IntType<'ctx>,
ctx: &'ctx Context,
scrutinee: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError>;
}
/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`].
///
/// For instance,
@ -59,16 +76,24 @@ impl ModelError {
/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time:
/// let my_value = Int(Int32).believe_value(my_value);
/// ```
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
pub trait Model<'ctx>: fmt::Debug + Clone + Copy + ModelBase<'ctx> {
/// The [`BasicType`] *variant* this model is identifying.
type Type: BasicType<'ctx>;
type Type: BasicType<'ctx> + TryFrom<BasicTypeEnum<'ctx>>;
/// The [`BasicValue`] type of the [`BasicType`] of this model.
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
/// Return the [`BasicType`] of this model.
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
let size_t = generator.get_size_type(ctx);
let ty = self.get_type_impl(size_t, ctx);
match Self::Type::try_from(ty) {
Ok(ty) => ty,
_ => panic!("Model::Type is inconsistent with what is returned from ModelBase::get_type_impl()! Got {ty:?}."),
}
}
/// Get the number of bytes of the [`BasicType`] of this model.
fn sizeof<G: CodeGenerator + ?Sized>(
@ -85,7 +110,10 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError>;
) -> Result<(), ModelError> {
let size_t = generator.get_size_type(ctx);
self.check_type_impl(size_t, ctx, ty.as_basic_type_enum())
}
/// Create an instance from a value.
///

View File

@ -2,20 +2,14 @@ use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, FloatType},
types::{BasicTypeEnum, FloatType, IntType},
values::FloatValue,
};
use crate::codegen::CodeGenerator;
use super::*;
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx>;
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
@ -24,21 +18,13 @@ pub struct Float32;
pub struct Float64;
impl<'ctx> FloatKind<'ctx> for Float32 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> {
ctx.f32_type()
}
}
impl<'ctx> FloatKind<'ctx> for Float64 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> {
ctx.f64_type()
}
}
@ -47,11 +33,7 @@ impl<'ctx> FloatKind<'ctx> for Float64 {
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> FloatType<'ctx> {
fn get_float_type(&self, _ctx: &'ctx Context) -> FloatType<'ctx> {
self.0
}
}
@ -59,32 +41,31 @@ impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
#[derive(Debug, Clone, Copy, Default)]
pub struct Float<N>(pub N);
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_float_type(generator, ctx)
impl<'ctx, N: FloatKind<'ctx>> ModelBase<'ctx> for Float<N> {
fn get_type_impl(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.get_float_type(ctx).into()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
generator: &mut G,
_size_t: IntType<'ctx>,
ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = FloatType::try_from(ty) else {
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
};
let exp_ty = self.0.get_float_type(generator, ctx);
// TODO: Inkwell does not have get_bit_width for FloatType?
if ty != exp_ty {
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
let expected_ty = self.0.get_float_type(ctx);
if ty != expected_ty {
return Err(ModelError(format!("Expecting {expected_ty:?}, but got {ty:?}")));
}
Ok(())
}
}
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
}

View File

@ -2,7 +2,7 @@ use std::{cmp::Ordering, fmt};
use inkwell::{
context::Context,
types::{BasicType, IntType},
types::{BasicTypeEnum, IntType},
values::IntValue,
IntPredicate,
};
@ -12,11 +12,7 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx>;
fn get_int_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
@ -31,52 +27,32 @@ pub struct Int64;
pub struct SizeT;
impl<'ctx> IntKind<'ctx> for Bool {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.bool_type()
}
}
impl<'ctx> IntKind<'ctx> for Byte {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i8_type()
}
}
impl<'ctx> IntKind<'ctx> for Int32 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i32_type()
}
}
impl<'ctx> IntKind<'ctx> for Int64 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
ctx.i64_type()
}
}
impl<'ctx> IntKind<'ctx> for SizeT {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
generator.get_size_type(ctx)
fn get_int_type(&self, size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> {
size_t
}
}
@ -84,11 +60,7 @@ impl<'ctx> IntKind<'ctx> for SizeT {
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> IntType<'ctx> {
fn get_int_type(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> {
self.0
}
}
@ -96,26 +68,22 @@ impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
#[derive(Debug, Clone, Copy, Default)]
pub struct Int<N>(pub N);
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
type Value = IntValue<'ctx>;
type Type = IntType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_int_type(generator, ctx)
impl<'ctx, N: IntKind<'ctx>> ModelBase<'ctx> for Int<N> {
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.get_int_type(size_t, ctx).into()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
generator: &mut G,
size_t: IntType<'ctx>,
ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = IntType::try_from(ty) else {
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
};
let exp_ty = self.0.get_int_type(generator, ctx);
let exp_ty = self.0.get_int_type(size_t, ctx);
if ty.get_bit_width() != exp_ty.get_bit_width() {
return Err(ModelError(format!(
"Expecting IntType to have {} bit(s), but got {} bit(s)",
@ -128,6 +96,11 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
}
}
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
type Type = IntType<'ctx>;
type Value = IntValue<'ctx>;
}
impl<'ctx, N: IntKind<'ctx>> Int<N> {
pub fn const_int<G: CodeGenerator + ?Sized>(
&self,
@ -173,7 +146,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
<= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value = ctx
.builder
@ -190,7 +163,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
< self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
@ -205,7 +178,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
<= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value = ctx
.builder
@ -222,7 +195,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
< self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
@ -237,7 +210,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
>= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
>= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value = ctx
.builder
@ -254,7 +227,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
) -> Instance<'ctx, Self> {
assert!(
value.get_type().get_bit_width()
> self.0.get_int_type(generator, ctx.ctx).get_bit_width()
> self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
);
let value =
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap();
@ -269,7 +242,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
let their_width = value.get_type().get_bit_width();
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
let our_width =
self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) {
Ordering::Less => self.s_extend(generator, ctx, value),
Ordering::Equal => self.believe_value(value),
@ -285,7 +259,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
value: IntValue<'ctx>,
) -> Instance<'ctx, Self> {
let their_width = value.get_type().get_bit_width();
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
let our_width =
self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width();
match their_width.cmp(&our_width) {
Ordering::Less => self.z_extend(generator, ctx, value),
Ordering::Equal => self.believe_value(value),

View File

@ -1,6 +1,6 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, PointerType},
types::{BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
@ -23,26 +23,23 @@ pub struct Ptr<Item>(pub Item);
/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers.
pub type OpaquePtr = Ptr<()>;
// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be
// TODO: LLVM 15: `Item: ModelBase<'ctx>` don't even need to be a model anymore. It will only be
// a type-hint for the `.load()/.store()` functions for the `pointee_ty`.
//
// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load.
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
impl<'ctx, Item: ModelBase<'ctx>> ModelBase<'ctx> for Ptr<Item> {
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
let item = self.0.get_type_impl(size_t, ctx);
item.ptr_type(AddressSpace::default()).into()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
generator: &mut G,
size_t: IntType<'ctx>,
ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = PointerType::try_from(ty) else {
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
};
@ -57,13 +54,18 @@ impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
// TODO: inkwell `get_element_type()` will be deprecated.
// Remove the check for `get_element_type()` when the time comes.
self.0
.check_type(generator, ctx, elem_ty)
.check_type_impl(size_t, ctx, elem_ty)
.map_err(|err| err.under_context("a PointerType"))?;
Ok(())
}
}
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
type Type = PointerType<'ctx>;
type Value = PointerValue<'ctx>;
}
impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
/// Return a ***constant*** nullptr.
pub fn nullptr<G: CodeGenerator + ?Sized>(
@ -71,6 +73,7 @@ impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
generator: &mut G,
ctx: &'ctx Context,
) -> Instance<'ctx, Ptr<Item>> {
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
let ptr = self.get_type(generator, ctx).const_null();
self.believe_value(ptr)
}

View File

@ -1,290 +1,141 @@
use std::fmt;
use std::{fmt, marker::PhantomData};
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, StructType},
types::{BasicType, BasicTypeEnum, IntType, StructType},
values::{BasicValueEnum, StructValue},
};
use itertools::{izip, Itertools};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types.
pub trait FieldTraversal<'ctx> {
/// Output type of [`FieldTraversal::add`].
type Out<M>;
// pub trait StructKind2<'ctx>: fmt::Debug + Clone + Copy {
// type Fields<F: FieldTraversal2<'ctx>> = ;
// }
/// Traverse through the type of a declared field and do something with it.
///
/// * `name` - The cosmetic name of the LLVM field. Used for debugging.
/// * `model` - The [`Model`] representing the LLVM type of this field.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M>;
pub struct Field<M> {
gep_index: u32,
model: M,
name: &'static str,
}
/// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait.
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Out<M> {
// NOTE: Very similar to Field, but is forall on `M`, (and also uses ModelBase to get object safety for the `Box<dyn ____>`.
pub struct Entry<'ctx> {
model: Box<dyn ModelBase<'ctx> + 'ctx>,
name: &'static str,
}
pub struct FieldMapper<'ctx> {
gep_index_counter: u32,
entries: Vec<Entry<'ctx>>,
}
impl<'ctx> FieldMapper<'ctx> {
fn add<M: 'ctx + Model<'ctx>>(&mut self, name: &'static str, model: M) -> Field<M> {
let entry = Entry { model: Box::new(model), name };
self.entries.push(entry);
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Field { gep_index, model, name }
}
fn add_auto<M: 'ctx + Model<'ctx> + Default>(&mut self, name: &'static str) -> Field<M> {
self.add(name, M::default())
}
}
/// Descriptor of an LLVM struct field.
#[derive(Debug, Clone, Copy)]
pub struct GepField<M> {
/// The GEP index of this field. This is the index to use with `build_gep`.
pub gep_index: u64,
/// The cosmetic name of this field.
pub name: &'static str,
/// The [`Model`] of this field's type.
pub model: M,
}
/// A traversal to calculate the GEP index of fields.
pub struct GepFieldTraversal {
/// The current GEP index.
gep_index_counter: u64,
}
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
type Out<M> = GepField<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Self::Out { gep_index, name, model }
}
}
/// A traversal to collect the field types of a struct.
///
/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`].
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a G,
ctx: &'ctx Context,
/// The collected field types so far in exact order.
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Out<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
let t = model.get_type(self.generator, self.ctx).as_basic_type_enum();
self.field_types.push(t);
}
}
/// A traversal to check the types of fields.
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a mut G,
ctx: &'ctx Context,
/// The current GEP index, so we can tell the index of the field we are checking
/// and report the GEP index.
gep_index_counter: u32,
/// The [`StructType`] to check.
scrutinee: StructType<'ctx>,
/// The list of collected errors so far.
errors: Vec<ModelError>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for CheckTypeFieldTraversal<'ctx, 'a, G>
{
type Out<M> = (); // Checking types return nothing.
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) {
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
self.errors
.push(err.under_context(format!("field #{gep_index} '{name}'").as_str()));
}
} // Otherwise, it will be caught by Struct's `check_type`.
}
}
/// A trait for Rust structs identifying LLVM structures.
///
/// ### Example
///
/// Suppose you want to define this structure:
/// ```c
/// template <typename T>
/// struct ContiguousNDArray {
/// size_t ndims;
/// size_t* shape;
/// T* data;
/// }
/// ```
///
/// This is how it should be done:
/// ```ignore
/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
/// pub ndims: F::Out<Int<SizeT>>,
/// pub shape: F::Out<Ptr<Int<SizeT>>>,
/// pub data: F::Out<Ptr<Item>>,
/// }
///
/// /// An ndarray without strides and non-opaque `data` field in NAC3.
/// #[derive(Debug, Clone, Copy)]
/// pub struct ContiguousNDArray<M> {
/// /// [`Model`] of the items.
/// pub item: M,
/// }
///
/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
/// type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
///
/// fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
/// // The order of `traversal.add*` is important
/// Self::Fields {
/// ndims: traversal.add_auto("ndims"),
/// shape: traversal.add_auto("shape"),
/// data: traversal.add("data", Ptr(self.item)),
/// }
/// }
/// }
/// ```
///
/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be
/// traversed to do useful work such as:
///
/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields.
/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax.
///
/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray`
/// with dtype `float64` in LLVM, this is how you do it:
/// ```ignore
/// type F64NDArray = Struct<ContiguousNDArray<Float<Float64>>>; // Type alias for leaner documentation
/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) });
/// let ndarray: Instance<'ctx, Ptr<F64NDArray>> = model.alloca(generator, ctx);
/// ```
///
/// ...and here is how you may manipulate/access `ndarray`:
///
/// (NOTE: some arguments have been omitted)
///
/// ```ignore
/// // Get `&ndarray->data`
/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr<Float<Float64>>>
///
/// // Get `ndarray->ndims`
/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int<SizeT>>
///
/// // Get `&ndarray->ndims`
/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr<Int<SizeT>>>
///
/// // Get `ndarray->shape[0]`
/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int<SizeT>>
///
/// // Get `&ndarray->shape[2]`
/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr<Int<SizeT>>>
///
/// // Do `ndarray->ndims = 3;`
/// let num_3 = Int(SizeT).const_int(3);
/// ndarray.set(|f| f.ndims, num_3);
/// ```
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
/// The associated fields of this struct.
type Fields<F: FieldTraversal<'ctx>>;
type Fields;
/// Traverse through all fields of this [`StructKind`].
///
/// Only used internally in this module for implementing other components.
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields;
/// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field.
///
/// Only used internally in this module for implementing other components.
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
// Produce `Vec<Entry>` and `Self::Fields` simultaneously.
// The former is for doing field-wise type checks.
// The latter is for enabling the `.gep(|f| f.data)` syntax.
fn entries_and_fields(&self) -> (Vec<Entry<'ctx>>, Self::Fields) {
let mut mapper = FieldMapper { gep_index_counter: 0, entries: Vec::new() };
let fields = self.iter_fields(&mut mapper);
(mapper.entries, fields)
}
fn entries(&self) -> Vec<Entry<'ctx>> {
self.entries_and_fields().0
}
fn fields(&self) -> Self::Fields {
self.entries_and_fields().1
}
/// Get the LLVM [`StructType`] of this [`StructKind`].
fn get_struct_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> StructType<'ctx> {
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
self.traverse_fields(&mut traversal);
ctx.struct_type(&traversal.field_types, false)
fn get_struct_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> StructType<'ctx> {
let entries = self.entries();
let entries = entries.into_iter().map(|t| t.model.get_type_impl(size_t, ctx)).collect_vec();
ctx.struct_type(&entries, false)
}
}
/// A model for LLVM struct.
///
/// `S` should be of a [`StructKind`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Struct<S>(pub S);
impl<'ctx, S: StructKind<'ctx>> Struct<S> {
/// Create a constant struct value from its fields.
///
/// This function also validates `fields` and panic when there is something wrong.
pub fn const_struct<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
fields: &[BasicValueEnum<'ctx>],
) -> Instance<'ctx, Self> {
// NOTE: There *could* have been a functor `F<M> = Instance<'ctx, M>` for `S::Fields<F>`
// to create a more user-friendly interface, but Rust's type system is not sophisticated enough
// and if you try doing that Rust would force you put lifetimes everywhere.
let val = ctx.const_struct(fields, false);
self.check_value(generator, ctx, val).unwrap()
}
}
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
type Value = StructValue<'ctx>;
type Type = StructType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_struct_type(generator, ctx)
impl<'ctx, S: StructKind<'ctx>> ModelBase<'ctx> for Struct<S> {
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.get_struct_type(size_t, ctx).as_basic_type_enum()
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
fn check_type_impl(
&self,
generator: &mut G,
size_t: IntType<'ctx>,
ctx: &'ctx Context,
ty: T,
ty: BasicTypeEnum<'ctx>,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = StructType::try_from(ty) else {
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
};
// Check each field individually.
let mut traversal = CheckTypeFieldTraversal {
generator,
ctx,
gep_index_counter: 0,
errors: Vec::new(),
scrutinee: ty,
};
self.0.traverse_fields(&mut traversal);
let entries = self.0.entries();
let field_types = ty.get_field_types();
// Check the number of fields.
let exp_num_fields = traversal.gep_index_counter;
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
if exp_num_fields != got_num_fields {
if entries.len() != field_types.len() {
return Err(ModelError(format!(
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
"Expecting StructType with {} field(s), but got {}",
entries.len(),
field_types.len()
)));
}
if !traversal.errors.is_empty() {
// Currently, only the first error is reported.
return Err(traversal.errors[0].clone());
// Check each field.
for (i, (entry, field_type)) in izip!(entries, field_types).enumerate() {
entry.model.check_type_impl(size_t, ctx, field_type).map_err(|err| {
let context = &format!("in field #{i} '{}'", entry.name);
err.under_context(context)
})?;
}
Ok(())
}
}
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
type Type = StructType<'ctx>;
type Value = StructValue<'ctx>;
}
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
/// Get a field with [`StructValue::get_field_at_index`].
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
@ -295,10 +146,10 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
GetField: FnOnce(S::Fields) -> Field<M>,
{
let field = get_field(self.model.0.fields());
let val = self.value.get_field_at_index(field.gep_index as u32).unwrap();
let val = self.value.get_field_at_index(field.gep_index).unwrap();
field.model.check_value(generator, ctx, val).unwrap()
}
}
@ -312,7 +163,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
) -> Instance<'ctx, Ptr<M>>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
GetField: FnOnce(S::Fields) -> Field<M>,
{
let field = get_field(self.model.0 .0.fields());
let llvm_i32 = ctx.ctx.i32_type();
@ -321,7 +172,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
ctx.builder
.build_in_bounds_gep(
self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
&[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)],
field.name,
)
.unwrap()
@ -339,7 +190,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
GetField: FnOnce(S::Fields) -> Field<M>,
{
self.gep(ctx, get_field).load(generator, ctx)
}
@ -352,8 +203,65 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
value: Instance<'ctx, M>,
) where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
GetField: FnOnce(S::Fields) -> Field<M>,
{
self.gep(ctx, get_field).store(ctx, value);
}
}
/////////////////////// Example; Delete later
// Example: NDArray.
//
// Compared to List, it has no generic models.
pub struct NDArrayFields {
data: Field<Ptr<Int<Byte>>>,
itemsize: Field<Int<SizeT>>,
ndims: Field<Int<SizeT>>,
shape: Field<Ptr<Int<SizeT>>>,
strides: Field<Ptr<Int<SizeT>>>,
}
#[derive(Debug, Clone, Copy, Default)]
struct NDArray;
impl<'ctx> StructKind<'ctx> for NDArray {
type Fields = NDArrayFields;
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields {
NDArrayFields {
data: mapper.add_auto("data"),
itemsize: mapper.add_auto("itemsize"),
ndims: mapper.add_auto("ndims"),
shape: mapper.add_auto("shape"),
strides: mapper.add_auto("strides"),
}
}
}
// Example: List.
//
// Compared to NDArray, it has generic models.
pub struct ListFields<'ctx, Item: Model<'ctx>> {
items: Field<Ptr<Item>>,
len: Field<Int<SizeT>>,
_phantom: PhantomData<&'ctx ()>,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct List<'ctx, Item: Model<'ctx>> {
item: Item,
_phantom: PhantomData<&'ctx ()>,
}
impl<'ctx, Item: Model<'ctx> + 'ctx> StructKind<'ctx> for List<'ctx, Item> {
type Fields = ListFields<'ctx, Item>;
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields {
ListFields {
items: mapper.add("items", Ptr(self.item)),
len: mapper.add_auto("len"),
_phantom: PhantomData,
}
}
}