redesign
This commit is contained in:
parent
dfbbe66154
commit
4e714cb53b
|
@ -1,11 +1,9 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicType, BasicTypeEnum},
|
types::{BasicType, BasicTypeEnum, IntType},
|
||||||
values::BasicValueEnum,
|
values::IntValue,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::codegen::CodeGenerator;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
/// A [`Model`] of any [`BasicTypeEnum`].
|
/// A [`Model`] of any [`BasicTypeEnum`].
|
||||||
|
@ -14,25 +12,17 @@ use super::*;
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
|
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
|
||||||
|
|
||||||
impl<'ctx> Model<'ctx> for Any<'ctx> {
|
impl<'ctx> ModelBase<'ctx> for Any<'ctx> {
|
||||||
type Value = BasicValueEnum<'ctx>;
|
fn get_type_impl(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = BasicTypeEnum<'ctx>;
|
self.0.as_basic_type_enum()
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> Self::Type {
|
|
||||||
self.0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
_generator: &mut G,
|
_size_t: IntType<'ctx>,
|
||||||
_ctx: &'ctx Context,
|
_ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
if ty == self.0 {
|
if ty == self.0 {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
@ -40,3 +30,8 @@ impl<'ctx> Model<'ctx> for Any<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx> Model<'ctx> for Any<'ctx> {
|
||||||
|
type Type = IntType<'ctx>;
|
||||||
|
type Value = IntValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::fmt;
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{ArrayType, BasicType, BasicTypeEnum},
|
types::{ArrayType, BasicType, BasicTypeEnum, IntType},
|
||||||
values::{ArrayValue, IntValue},
|
values::{ArrayValue, IntValue},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -46,21 +46,18 @@ pub struct Array<Len, Item> {
|
||||||
pub item: Item,
|
pub item: Item,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
|
impl<'ctx, Len: LenKind, Item: ModelBase<'ctx>> ModelBase<'ctx> for Array<Len, Item> {
|
||||||
type Value = ArrayValue<'ctx>;
|
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = ArrayType<'ctx>;
|
let item = self.item.get_type_impl(size_t, ctx);
|
||||||
|
item.array_type(self.len.get_length()).into()
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
self.item.get_type(generator, ctx).array_type(self.len.get_length())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
size_t: IntType<'ctx>,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let BasicTypeEnum::ArrayType(ty) = ty else {
|
let BasicTypeEnum::ArrayType(ty) = ty else {
|
||||||
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
|
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
|
||||||
};
|
};
|
||||||
|
@ -74,13 +71,18 @@ impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
|
||||||
}
|
}
|
||||||
|
|
||||||
self.item
|
self.item
|
||||||
.check_type(generator, ctx, ty.get_element_type())
|
.check_type_impl(size_t, ctx, ty.get_element_type())
|
||||||
.map_err(|err| err.under_context("an ArrayType"))?;
|
.map_err(|err| err.under_context("an ArrayType"))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array<Len, Item> {
|
||||||
|
type Type = ArrayType<'ctx>;
|
||||||
|
type Value = ArrayValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
|
impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr<Array<Len, Item>>> {
|
||||||
/// Get the pointer to the `i`-th (0-based) array element.
|
/// Get the pointer to the `i`-th (0-based) array element.
|
||||||
pub fn gep(
|
pub fn gep(
|
||||||
|
|
|
@ -19,6 +19,23 @@ impl ModelError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: A watered down version of `Model` trait. Made to be object safe.
|
||||||
|
pub trait ModelBase<'ctx> {
|
||||||
|
// NOTE: Taking `size_t` here instead of `CodeGenerator` to be object safe.
|
||||||
|
// In fact, all the entire model abstraction need from the `CodeGenerator` is its `get_size_type()`.
|
||||||
|
|
||||||
|
// NOTE: Model's get_type but object-safe and returns BasicTypeEnum, instead of a known BasicType variant.
|
||||||
|
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||||
|
|
||||||
|
// NOTE: Model's check_type but object-safe.
|
||||||
|
fn check_type_impl(
|
||||||
|
&self,
|
||||||
|
size_t: IntType<'ctx>,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
scrutinee: BasicTypeEnum<'ctx>,
|
||||||
|
) -> Result<(), ModelError>;
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`].
|
/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`].
|
||||||
///
|
///
|
||||||
/// For instance,
|
/// For instance,
|
||||||
|
@ -59,16 +76,24 @@ impl ModelError {
|
||||||
/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time:
|
/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time:
|
||||||
/// let my_value = Int(Int32).believe_value(my_value);
|
/// let my_value = Int(Int32).believe_value(my_value);
|
||||||
/// ```
|
/// ```
|
||||||
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
pub trait Model<'ctx>: fmt::Debug + Clone + Copy + ModelBase<'ctx> {
|
||||||
/// The [`BasicType`] *variant* this model is identifying.
|
/// The [`BasicType`] *variant* this model is identifying.
|
||||||
type Type: BasicType<'ctx>;
|
type Type: BasicType<'ctx> + TryFrom<BasicTypeEnum<'ctx>>;
|
||||||
|
|
||||||
/// The [`BasicValue`] type of the [`BasicType`] of this model.
|
/// The [`BasicValue`] type of the [`BasicType`] of this model.
|
||||||
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
|
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
|
||||||
|
|
||||||
/// Return the [`BasicType`] of this model.
|
/// Return the [`BasicType`] of this model.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
|
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
||||||
|
let size_t = generator.get_size_type(ctx);
|
||||||
|
|
||||||
|
let ty = self.get_type_impl(size_t, ctx);
|
||||||
|
match Self::Type::try_from(ty) {
|
||||||
|
Ok(ty) => ty,
|
||||||
|
_ => panic!("Model::Type is inconsistent with what is returned from ModelBase::get_type_impl()! Got {ty:?}."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the number of bytes of the [`BasicType`] of this model.
|
/// Get the number of bytes of the [`BasicType`] of this model.
|
||||||
fn sizeof<G: CodeGenerator + ?Sized>(
|
fn sizeof<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -85,7 +110,10 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: T,
|
||||||
) -> Result<(), ModelError>;
|
) -> Result<(), ModelError> {
|
||||||
|
let size_t = generator.get_size_type(ctx);
|
||||||
|
self.check_type_impl(size_t, ctx, ty.as_basic_type_enum())
|
||||||
|
}
|
||||||
|
|
||||||
/// Create an instance from a value.
|
/// Create an instance from a value.
|
||||||
///
|
///
|
||||||
|
|
|
@ -2,20 +2,14 @@ use std::fmt;
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicType, FloatType},
|
types::{BasicTypeEnum, FloatType, IntType},
|
||||||
values::FloatValue,
|
values::FloatValue,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::codegen::CodeGenerator;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
|
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx>;
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
@ -24,21 +18,13 @@ pub struct Float32;
|
||||||
pub struct Float64;
|
pub struct Float64;
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for Float32 {
|
impl<'ctx> FloatKind<'ctx> for Float32 {
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
ctx.f32_type()
|
ctx.f32_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for Float64 {
|
impl<'ctx> FloatKind<'ctx> for Float64 {
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
fn get_float_type(&self, ctx: &'ctx Context) -> FloatType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
ctx.f64_type()
|
ctx.f64_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,11 +33,7 @@ impl<'ctx> FloatKind<'ctx> for Float64 {
|
||||||
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
|
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
|
||||||
|
|
||||||
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
|
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
|
||||||
fn get_float_type<G: CodeGenerator + ?Sized>(
|
fn get_float_type(&self, _ctx: &'ctx Context) -> FloatType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> FloatType<'ctx> {
|
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -59,32 +41,31 @@ impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct Float<N>(pub N);
|
pub struct Float<N>(pub N);
|
||||||
|
|
||||||
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
|
impl<'ctx, N: FloatKind<'ctx>> ModelBase<'ctx> for Float<N> {
|
||||||
type Value = FloatValue<'ctx>;
|
fn get_type_impl(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = FloatType<'ctx>;
|
self.0.get_float_type(ctx).into()
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
self.0.get_float_type(generator, ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
_size_t: IntType<'ctx>,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = FloatType::try_from(ty) else {
|
let Ok(ty) = FloatType::try_from(ty) else {
|
||||||
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
|
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
|
||||||
};
|
};
|
||||||
|
|
||||||
let exp_ty = self.0.get_float_type(generator, ctx);
|
let expected_ty = self.0.get_float_type(ctx);
|
||||||
|
if ty != expected_ty {
|
||||||
// TODO: Inkwell does not have get_bit_width for FloatType?
|
return Err(ModelError(format!("Expecting {expected_ty:?}, but got {ty:?}")));
|
||||||
if ty != exp_ty {
|
|
||||||
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float<N> {
|
||||||
|
type Value = FloatValue<'ctx>;
|
||||||
|
type Type = FloatType<'ctx>;
|
||||||
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::{cmp::Ordering, fmt};
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicType, IntType},
|
types::{BasicTypeEnum, IntType},
|
||||||
values::IntValue,
|
values::IntValue,
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
|
@ -12,11 +12,7 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
|
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx>;
|
||||||
&self,
|
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
@ -31,52 +27,32 @@ pub struct Int64;
|
||||||
pub struct SizeT;
|
pub struct SizeT;
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Bool {
|
impl<'ctx> IntKind<'ctx> for Bool {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.bool_type()
|
ctx.bool_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Byte {
|
impl<'ctx> IntKind<'ctx> for Byte {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i8_type()
|
ctx.i8_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Int32 {
|
impl<'ctx> IntKind<'ctx> for Int32 {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i32_type()
|
ctx.i32_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for Int64 {
|
impl<'ctx> IntKind<'ctx> for Int64 {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, _size_t: IntType<'ctx>, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
ctx.i64_type()
|
ctx.i64_type()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for SizeT {
|
impl<'ctx> IntKind<'ctx> for SizeT {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
size_t
|
||||||
generator: &G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
generator.get_size_type(ctx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,11 +60,7 @@ impl<'ctx> IntKind<'ctx> for SizeT {
|
||||||
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
|
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
|
||||||
|
|
||||||
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
|
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
|
||||||
fn get_int_type<G: CodeGenerator + ?Sized>(
|
fn get_int_type(&self, _size_t: IntType<'ctx>, _ctx: &'ctx Context) -> IntType<'ctx> {
|
||||||
&self,
|
|
||||||
_generator: &G,
|
|
||||||
_ctx: &'ctx Context,
|
|
||||||
) -> IntType<'ctx> {
|
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,26 +68,22 @@ impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct Int<N>(pub N);
|
pub struct Int<N>(pub N);
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
|
impl<'ctx, N: IntKind<'ctx>> ModelBase<'ctx> for Int<N> {
|
||||||
type Value = IntValue<'ctx>;
|
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = IntType<'ctx>;
|
self.0.get_int_type(size_t, ctx).into()
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
self.0.get_int_type(generator, ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
size_t: IntType<'ctx>,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = IntType::try_from(ty) else {
|
let Ok(ty) = IntType::try_from(ty) else {
|
||||||
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
|
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
|
||||||
};
|
};
|
||||||
|
|
||||||
let exp_ty = self.0.get_int_type(generator, ctx);
|
let exp_ty = self.0.get_int_type(size_t, ctx);
|
||||||
if ty.get_bit_width() != exp_ty.get_bit_width() {
|
if ty.get_bit_width() != exp_ty.get_bit_width() {
|
||||||
return Err(ModelError(format!(
|
return Err(ModelError(format!(
|
||||||
"Expecting IntType to have {} bit(s), but got {} bit(s)",
|
"Expecting IntType to have {} bit(s), but got {} bit(s)",
|
||||||
|
@ -128,6 +96,11 @@ impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int<N> {
|
||||||
|
type Type = IntType<'ctx>;
|
||||||
|
type Value = IntValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
pub fn const_int<G: CodeGenerator + ?Sized>(
|
pub fn const_int<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
|
@ -173,7 +146,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
<= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value = ctx
|
let value = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -190,7 +163,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
< self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value =
|
let value =
|
||||||
ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
||||||
|
@ -205,7 +178,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
<= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
<= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value = ctx
|
let value = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -222,7 +195,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
< self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
< self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value =
|
let value =
|
||||||
ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
||||||
|
@ -237,7 +210,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
>= self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
>= self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value = ctx
|
let value = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -254,7 +227,7 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
assert!(
|
assert!(
|
||||||
value.get_type().get_bit_width()
|
value.get_type().get_bit_width()
|
||||||
> self.0.get_int_type(generator, ctx.ctx).get_bit_width()
|
> self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width()
|
||||||
);
|
);
|
||||||
let value =
|
let value =
|
||||||
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap();
|
||||||
|
@ -269,7 +242,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
value: IntValue<'ctx>,
|
value: IntValue<'ctx>,
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
let their_width = value.get_type().get_bit_width();
|
let their_width = value.get_type().get_bit_width();
|
||||||
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
|
let our_width =
|
||||||
|
self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width();
|
||||||
match their_width.cmp(&our_width) {
|
match their_width.cmp(&our_width) {
|
||||||
Ordering::Less => self.s_extend(generator, ctx, value),
|
Ordering::Less => self.s_extend(generator, ctx, value),
|
||||||
Ordering::Equal => self.believe_value(value),
|
Ordering::Equal => self.believe_value(value),
|
||||||
|
@ -285,7 +259,8 @@ impl<'ctx, N: IntKind<'ctx>> Int<N> {
|
||||||
value: IntValue<'ctx>,
|
value: IntValue<'ctx>,
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
let their_width = value.get_type().get_bit_width();
|
let their_width = value.get_type().get_bit_width();
|
||||||
let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width();
|
let our_width =
|
||||||
|
self.0.get_int_type(generator.get_size_type(ctx.ctx), ctx.ctx).get_bit_width();
|
||||||
match their_width.cmp(&our_width) {
|
match their_width.cmp(&our_width) {
|
||||||
Ordering::Less => self.z_extend(generator, ctx, value),
|
Ordering::Less => self.z_extend(generator, ctx, value),
|
||||||
Ordering::Equal => self.believe_value(value),
|
Ordering::Equal => self.believe_value(value),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicType, BasicTypeEnum, PointerType},
|
types::{BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{IntValue, PointerValue},
|
values::{IntValue, PointerValue},
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
};
|
};
|
||||||
|
@ -23,26 +23,23 @@ pub struct Ptr<Item>(pub Item);
|
||||||
/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers.
|
/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers.
|
||||||
pub type OpaquePtr = Ptr<()>;
|
pub type OpaquePtr = Ptr<()>;
|
||||||
|
|
||||||
// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be
|
// TODO: LLVM 15: `Item: ModelBase<'ctx>` don't even need to be a model anymore. It will only be
|
||||||
// a type-hint for the `.load()/.store()` functions for the `pointee_ty`.
|
// a type-hint for the `.load()/.store()` functions for the `pointee_ty`.
|
||||||
//
|
//
|
||||||
// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load.
|
// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load.
|
||||||
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
|
impl<'ctx, Item: ModelBase<'ctx>> ModelBase<'ctx> for Ptr<Item> {
|
||||||
type Value = PointerValue<'ctx>;
|
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = PointerType<'ctx>;
|
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
|
// TODO: LLVM 15: ctx.ptr_type(AddressSpace::default())
|
||||||
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
|
let item = self.0.get_type_impl(size_t, ctx);
|
||||||
|
item.ptr_type(AddressSpace::default()).into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
size_t: IntType<'ctx>,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = PointerType::try_from(ty) else {
|
let Ok(ty) = PointerType::try_from(ty) else {
|
||||||
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
|
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
|
||||||
};
|
};
|
||||||
|
@ -57,13 +54,18 @@ impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
|
||||||
// TODO: inkwell `get_element_type()` will be deprecated.
|
// TODO: inkwell `get_element_type()` will be deprecated.
|
||||||
// Remove the check for `get_element_type()` when the time comes.
|
// Remove the check for `get_element_type()` when the time comes.
|
||||||
self.0
|
self.0
|
||||||
.check_type(generator, ctx, elem_ty)
|
.check_type_impl(size_t, ctx, elem_ty)
|
||||||
.map_err(|err| err.under_context("a PointerType"))?;
|
.map_err(|err| err.under_context("a PointerType"))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr<Item> {
|
||||||
|
type Type = PointerType<'ctx>;
|
||||||
|
type Value = PointerValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
|
impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
|
||||||
/// Return a ***constant*** nullptr.
|
/// Return a ***constant*** nullptr.
|
||||||
pub fn nullptr<G: CodeGenerator + ?Sized>(
|
pub fn nullptr<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -71,6 +73,7 @@ impl<'ctx, Item: Model<'ctx>> Ptr<Item> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
) -> Instance<'ctx, Ptr<Item>> {
|
) -> Instance<'ctx, Ptr<Item>> {
|
||||||
|
// TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`.
|
||||||
let ptr = self.get_type(generator, ctx).const_null();
|
let ptr = self.get_type(generator, ctx).const_null();
|
||||||
self.believe_value(ptr)
|
self.believe_value(ptr)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,290 +1,141 @@
|
||||||
use std::fmt;
|
use std::{fmt, marker::PhantomData};
|
||||||
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{BasicType, BasicTypeEnum, StructType},
|
types::{BasicType, BasicTypeEnum, IntType, StructType},
|
||||||
values::{BasicValueEnum, StructValue},
|
values::{BasicValueEnum, StructValue},
|
||||||
};
|
};
|
||||||
|
use itertools::{izip, Itertools};
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types.
|
// pub trait StructKind2<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
pub trait FieldTraversal<'ctx> {
|
// type Fields<F: FieldTraversal2<'ctx>> = ;
|
||||||
/// Output type of [`FieldTraversal::add`].
|
// }
|
||||||
type Out<M>;
|
|
||||||
|
|
||||||
/// Traverse through the type of a declared field and do something with it.
|
pub struct Field<M> {
|
||||||
///
|
gep_index: u32,
|
||||||
/// * `name` - The cosmetic name of the LLVM field. Used for debugging.
|
model: M,
|
||||||
/// * `model` - The [`Model`] representing the LLVM type of this field.
|
name: &'static str,
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M>;
|
}
|
||||||
|
|
||||||
/// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait.
|
// NOTE: Very similar to Field, but is forall on `M`, (and also uses ModelBase to get object safety for the `Box<dyn ____>`.
|
||||||
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Out<M> {
|
pub struct Entry<'ctx> {
|
||||||
|
model: Box<dyn ModelBase<'ctx> + 'ctx>,
|
||||||
|
name: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FieldMapper<'ctx> {
|
||||||
|
gep_index_counter: u32,
|
||||||
|
entries: Vec<Entry<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> FieldMapper<'ctx> {
|
||||||
|
fn add<M: 'ctx + Model<'ctx>>(&mut self, name: &'static str, model: M) -> Field<M> {
|
||||||
|
let entry = Entry { model: Box::new(model), name };
|
||||||
|
self.entries.push(entry);
|
||||||
|
|
||||||
|
let gep_index = self.gep_index_counter;
|
||||||
|
self.gep_index_counter += 1;
|
||||||
|
Field { gep_index, model, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_auto<M: 'ctx + Model<'ctx> + Default>(&mut self, name: &'static str) -> Field<M> {
|
||||||
self.add(name, M::default())
|
self.add(name, M::default())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Descriptor of an LLVM struct field.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct GepField<M> {
|
|
||||||
/// The GEP index of this field. This is the index to use with `build_gep`.
|
|
||||||
pub gep_index: u64,
|
|
||||||
/// The cosmetic name of this field.
|
|
||||||
pub name: &'static str,
|
|
||||||
/// The [`Model`] of this field's type.
|
|
||||||
pub model: M,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to calculate the GEP index of fields.
|
|
||||||
pub struct GepFieldTraversal {
|
|
||||||
/// The current GEP index.
|
|
||||||
gep_index_counter: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
|
|
||||||
type Out<M> = GepField<M>;
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
|
|
||||||
let gep_index = self.gep_index_counter;
|
|
||||||
self.gep_index_counter += 1;
|
|
||||||
Self::Out { gep_index, name, model }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to collect the field types of a struct.
|
|
||||||
///
|
|
||||||
/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`].
|
|
||||||
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
|
|
||||||
generator: &'a G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
/// The collected field types so far in exact order.
|
|
||||||
field_types: Vec<BasicTypeEnum<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
|
|
||||||
type Out<M> = (); // Checking types return nothing.
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
|
|
||||||
let t = model.get_type(self.generator, self.ctx).as_basic_type_enum();
|
|
||||||
self.field_types.push(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A traversal to check the types of fields.
|
|
||||||
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
|
|
||||||
generator: &'a mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
/// The current GEP index, so we can tell the index of the field we are checking
|
|
||||||
/// and report the GEP index.
|
|
||||||
gep_index_counter: u32,
|
|
||||||
/// The [`StructType`] to check.
|
|
||||||
scrutinee: StructType<'ctx>,
|
|
||||||
/// The list of collected errors so far.
|
|
||||||
errors: Vec<ModelError>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
|
|
||||||
for CheckTypeFieldTraversal<'ctx, 'a, G>
|
|
||||||
{
|
|
||||||
type Out<M> = (); // Checking types return nothing.
|
|
||||||
|
|
||||||
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
|
|
||||||
let gep_index = self.gep_index_counter;
|
|
||||||
self.gep_index_counter += 1;
|
|
||||||
|
|
||||||
if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) {
|
|
||||||
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
|
|
||||||
self.errors
|
|
||||||
.push(err.under_context(format!("field #{gep_index} '{name}'").as_str()));
|
|
||||||
}
|
|
||||||
} // Otherwise, it will be caught by Struct's `check_type`.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A trait for Rust structs identifying LLVM structures.
|
|
||||||
///
|
|
||||||
/// ### Example
|
|
||||||
///
|
|
||||||
/// Suppose you want to define this structure:
|
|
||||||
/// ```c
|
|
||||||
/// template <typename T>
|
|
||||||
/// struct ContiguousNDArray {
|
|
||||||
/// size_t ndims;
|
|
||||||
/// size_t* shape;
|
|
||||||
/// T* data;
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// This is how it should be done:
|
|
||||||
/// ```ignore
|
|
||||||
/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
|
|
||||||
/// pub ndims: F::Out<Int<SizeT>>,
|
|
||||||
/// pub shape: F::Out<Ptr<Int<SizeT>>>,
|
|
||||||
/// pub data: F::Out<Ptr<Item>>,
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// /// An ndarray without strides and non-opaque `data` field in NAC3.
|
|
||||||
/// #[derive(Debug, Clone, Copy)]
|
|
||||||
/// pub struct ContiguousNDArray<M> {
|
|
||||||
/// /// [`Model`] of the items.
|
|
||||||
/// pub item: M,
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
|
|
||||||
/// type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
|
|
||||||
///
|
|
||||||
/// fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
|
||||||
/// // The order of `traversal.add*` is important
|
|
||||||
/// Self::Fields {
|
|
||||||
/// ndims: traversal.add_auto("ndims"),
|
|
||||||
/// shape: traversal.add_auto("shape"),
|
|
||||||
/// data: traversal.add("data", Ptr(self.item)),
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be
|
|
||||||
/// traversed to do useful work such as:
|
|
||||||
///
|
|
||||||
/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields.
|
|
||||||
/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax.
|
|
||||||
///
|
|
||||||
/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray`
|
|
||||||
/// with dtype `float64` in LLVM, this is how you do it:
|
|
||||||
/// ```ignore
|
|
||||||
/// type F64NDArray = Struct<ContiguousNDArray<Float<Float64>>>; // Type alias for leaner documentation
|
|
||||||
/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) });
|
|
||||||
/// let ndarray: Instance<'ctx, Ptr<F64NDArray>> = model.alloca(generator, ctx);
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// ...and here is how you may manipulate/access `ndarray`:
|
|
||||||
///
|
|
||||||
/// (NOTE: some arguments have been omitted)
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// // Get `&ndarray->data`
|
|
||||||
/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr<Float<Float64>>>
|
|
||||||
///
|
|
||||||
/// // Get `ndarray->ndims`
|
|
||||||
/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int<SizeT>>
|
|
||||||
///
|
|
||||||
/// // Get `&ndarray->ndims`
|
|
||||||
/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr<Int<SizeT>>>
|
|
||||||
///
|
|
||||||
/// // Get `ndarray->shape[0]`
|
|
||||||
/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int<SizeT>>
|
|
||||||
///
|
|
||||||
/// // Get `&ndarray->shape[2]`
|
|
||||||
/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr<Int<SizeT>>>
|
|
||||||
///
|
|
||||||
/// // Do `ndarray->ndims = 3;`
|
|
||||||
/// let num_3 = Int(SizeT).const_int(3);
|
|
||||||
/// ndarray.set(|f| f.ndims, num_3);
|
|
||||||
/// ```
|
|
||||||
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
|
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
/// The associated fields of this struct.
|
type Fields;
|
||||||
type Fields<F: FieldTraversal<'ctx>>;
|
|
||||||
|
|
||||||
/// Traverse through all fields of this [`StructKind`].
|
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields;
|
||||||
///
|
|
||||||
/// Only used internally in this module for implementing other components.
|
|
||||||
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
|
|
||||||
|
|
||||||
/// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field.
|
// Produce `Vec<Entry>` and `Self::Fields` simultaneously.
|
||||||
///
|
// The former is for doing field-wise type checks.
|
||||||
/// Only used internally in this module for implementing other components.
|
// The latter is for enabling the `.gep(|f| f.data)` syntax.
|
||||||
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
|
fn entries_and_fields(&self) -> (Vec<Entry<'ctx>>, Self::Fields) {
|
||||||
self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
|
let mut mapper = FieldMapper { gep_index_counter: 0, entries: Vec::new() };
|
||||||
|
let fields = self.iter_fields(&mut mapper);
|
||||||
|
(mapper.entries, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn entries(&self) -> Vec<Entry<'ctx>> {
|
||||||
|
self.entries_and_fields().0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fields(&self) -> Self::Fields {
|
||||||
|
self.entries_and_fields().1
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the LLVM [`StructType`] of this [`StructKind`].
|
/// Get the LLVM [`StructType`] of this [`StructKind`].
|
||||||
fn get_struct_type<G: CodeGenerator + ?Sized>(
|
fn get_struct_type(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||||
&self,
|
let entries = self.entries();
|
||||||
generator: &G,
|
let entries = entries.into_iter().map(|t| t.model.get_type_impl(size_t, ctx)).collect_vec();
|
||||||
ctx: &'ctx Context,
|
ctx.struct_type(&entries, false)
|
||||||
) -> StructType<'ctx> {
|
|
||||||
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
|
|
||||||
self.traverse_fields(&mut traversal);
|
|
||||||
|
|
||||||
ctx.struct_type(&traversal.field_types, false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A model for LLVM struct.
|
|
||||||
///
|
|
||||||
/// `S` should be of a [`StructKind`].
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct Struct<S>(pub S);
|
pub struct Struct<S>(pub S);
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Struct<S> {
|
impl<'ctx, S: StructKind<'ctx>> Struct<S> {
|
||||||
/// Create a constant struct value from its fields.
|
|
||||||
///
|
|
||||||
/// This function also validates `fields` and panic when there is something wrong.
|
|
||||||
pub fn const_struct<G: CodeGenerator + ?Sized>(
|
pub fn const_struct<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
fields: &[BasicValueEnum<'ctx>],
|
fields: &[BasicValueEnum<'ctx>],
|
||||||
) -> Instance<'ctx, Self> {
|
) -> Instance<'ctx, Self> {
|
||||||
// NOTE: There *could* have been a functor `F<M> = Instance<'ctx, M>` for `S::Fields<F>`
|
|
||||||
// to create a more user-friendly interface, but Rust's type system is not sophisticated enough
|
|
||||||
// and if you try doing that Rust would force you put lifetimes everywhere.
|
|
||||||
let val = ctx.const_struct(fields, false);
|
let val = ctx.const_struct(fields, false);
|
||||||
self.check_value(generator, ctx, val).unwrap()
|
self.check_value(generator, ctx, val).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
|
impl<'ctx, S: StructKind<'ctx>> ModelBase<'ctx> for Struct<S> {
|
||||||
type Value = StructValue<'ctx>;
|
fn get_type_impl(&self, size_t: IntType<'ctx>, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||||
type Type = StructType<'ctx>;
|
self.0.get_struct_type(size_t, ctx).as_basic_type_enum()
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
self.0.get_struct_type(generator, ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
fn check_type_impl(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
size_t: IntType<'ctx>,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
ty: T,
|
ty: BasicTypeEnum<'ctx>,
|
||||||
) -> Result<(), ModelError> {
|
) -> Result<(), ModelError> {
|
||||||
let ty = ty.as_basic_type_enum();
|
|
||||||
let Ok(ty) = StructType::try_from(ty) else {
|
let Ok(ty) = StructType::try_from(ty) else {
|
||||||
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
|
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check each field individually.
|
let entries = self.0.entries();
|
||||||
let mut traversal = CheckTypeFieldTraversal {
|
let field_types = ty.get_field_types();
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
gep_index_counter: 0,
|
|
||||||
errors: Vec::new(),
|
|
||||||
scrutinee: ty,
|
|
||||||
};
|
|
||||||
self.0.traverse_fields(&mut traversal);
|
|
||||||
|
|
||||||
// Check the number of fields.
|
// Check the number of fields.
|
||||||
let exp_num_fields = traversal.gep_index_counter;
|
if entries.len() != field_types.len() {
|
||||||
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
|
|
||||||
if exp_num_fields != got_num_fields {
|
|
||||||
return Err(ModelError(format!(
|
return Err(ModelError(format!(
|
||||||
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
|
"Expecting StructType with {} field(s), but got {}",
|
||||||
|
entries.len(),
|
||||||
|
field_types.len()
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if !traversal.errors.is_empty() {
|
// Check each field.
|
||||||
// Currently, only the first error is reported.
|
for (i, (entry, field_type)) in izip!(entries, field_types).enumerate() {
|
||||||
return Err(traversal.errors[0].clone());
|
entry.model.check_type_impl(size_t, ctx, field_type).map_err(|err| {
|
||||||
|
let context = &format!("in field #{i} '{}'", entry.name);
|
||||||
|
err.under_context(context)
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct<S> {
|
||||||
|
type Type = StructType<'ctx>;
|
||||||
|
type Value = StructValue<'ctx>;
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
|
impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
|
||||||
/// Get a field with [`StructValue::get_field_at_index`].
|
/// Get a field with [`StructValue::get_field_at_index`].
|
||||||
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
|
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
|
||||||
|
@ -295,10 +146,10 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct<S>> {
|
||||||
) -> Instance<'ctx, M>
|
) -> Instance<'ctx, M>
|
||||||
where
|
where
|
||||||
M: Model<'ctx>,
|
M: Model<'ctx>,
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
GetField: FnOnce(S::Fields) -> Field<M>,
|
||||||
{
|
{
|
||||||
let field = get_field(self.model.0.fields());
|
let field = get_field(self.model.0.fields());
|
||||||
let val = self.value.get_field_at_index(field.gep_index as u32).unwrap();
|
let val = self.value.get_field_at_index(field.gep_index).unwrap();
|
||||||
field.model.check_value(generator, ctx, val).unwrap()
|
field.model.check_value(generator, ctx, val).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -312,7 +163,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
|
||||||
) -> Instance<'ctx, Ptr<M>>
|
) -> Instance<'ctx, Ptr<M>>
|
||||||
where
|
where
|
||||||
M: Model<'ctx>,
|
M: Model<'ctx>,
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
GetField: FnOnce(S::Fields) -> Field<M>,
|
||||||
{
|
{
|
||||||
let field = get_field(self.model.0 .0.fields());
|
let field = get_field(self.model.0 .0.fields());
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
@ -321,7 +172,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_in_bounds_gep(
|
.build_in_bounds_gep(
|
||||||
self.value,
|
self.value,
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(u64::from(field.gep_index), false)],
|
||||||
field.name,
|
field.name,
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -339,7 +190,7 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
|
||||||
) -> Instance<'ctx, M>
|
) -> Instance<'ctx, M>
|
||||||
where
|
where
|
||||||
M: Model<'ctx>,
|
M: Model<'ctx>,
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
GetField: FnOnce(S::Fields) -> Field<M>,
|
||||||
{
|
{
|
||||||
self.gep(ctx, get_field).load(generator, ctx)
|
self.gep(ctx, get_field).load(generator, ctx)
|
||||||
}
|
}
|
||||||
|
@ -352,8 +203,65 @@ impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr<Struct<S>>> {
|
||||||
value: Instance<'ctx, M>,
|
value: Instance<'ctx, M>,
|
||||||
) where
|
) where
|
||||||
M: Model<'ctx>,
|
M: Model<'ctx>,
|
||||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
GetField: FnOnce(S::Fields) -> Field<M>,
|
||||||
{
|
{
|
||||||
self.gep(ctx, get_field).store(ctx, value);
|
self.gep(ctx, get_field).store(ctx, value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////// Example; Delete later
|
||||||
|
|
||||||
|
// Example: NDArray.
|
||||||
|
//
|
||||||
|
// Compared to List, it has no generic models.
|
||||||
|
pub struct NDArrayFields {
|
||||||
|
data: Field<Ptr<Int<Byte>>>,
|
||||||
|
itemsize: Field<Int<SizeT>>,
|
||||||
|
ndims: Field<Int<SizeT>>,
|
||||||
|
shape: Field<Ptr<Int<SizeT>>>,
|
||||||
|
strides: Field<Ptr<Int<SizeT>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
struct NDArray;
|
||||||
|
|
||||||
|
impl<'ctx> StructKind<'ctx> for NDArray {
|
||||||
|
type Fields = NDArrayFields;
|
||||||
|
|
||||||
|
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields {
|
||||||
|
NDArrayFields {
|
||||||
|
data: mapper.add_auto("data"),
|
||||||
|
itemsize: mapper.add_auto("itemsize"),
|
||||||
|
ndims: mapper.add_auto("ndims"),
|
||||||
|
shape: mapper.add_auto("shape"),
|
||||||
|
strides: mapper.add_auto("strides"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: List.
|
||||||
|
//
|
||||||
|
// Compared to NDArray, it has generic models.
|
||||||
|
pub struct ListFields<'ctx, Item: Model<'ctx>> {
|
||||||
|
items: Field<Ptr<Item>>,
|
||||||
|
len: Field<Int<SizeT>>,
|
||||||
|
_phantom: PhantomData<&'ctx ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct List<'ctx, Item: Model<'ctx>> {
|
||||||
|
item: Item,
|
||||||
|
_phantom: PhantomData<&'ctx ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Item: Model<'ctx> + 'ctx> StructKind<'ctx> for List<'ctx, Item> {
|
||||||
|
type Fields = ListFields<'ctx, Item>;
|
||||||
|
|
||||||
|
fn iter_fields(&self, mapper: &mut FieldMapper<'ctx>) -> Self::Fields {
|
||||||
|
ListFields {
|
||||||
|
items: mapper.add("items", Ptr(self.item)),
|
||||||
|
len: mapper.add_auto("len"),
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue