forked from M-Labs/nac3
1
0
Fork 0

core: optics.rs abstract inkwell

This commit is contained in:
lyken 2024-07-13 19:08:57 +08:00
parent 5faac4b9d4
commit 259481e8d0
3 changed files with 437 additions and 0 deletions

View File

@ -0,0 +1,94 @@
use inkwell::types::{BasicTypeEnum, IntType};
use crate::codegen::optics::{AddressLens, GepGetter, IntLens, StructureOptic};
// use crate::codegen::structure::{
// FieldLensBuilder, IntLens, LensWithFieldInfo, PointerLens, StructFieldLens,
// };
pub struct NpArrayFields<'ctx> {
pub data: GepGetter<AddressLens<IntLens<'ctx>>>,
pub itemsize: GepGetter<IntLens<'ctx>>,
pub ndims: GepGetter<IntLens<'ctx>>,
pub shape: GepGetter<AddressLens<IntLens<'ctx>>>,
pub strides: GepGetter<AddressLens<IntLens<'ctx>>>,
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayLens<'ctx> {
pub size_type: IntType<'ctx>,
pub elem_type: BasicTypeEnum<'ctx>,
}
impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> {
type Fields = NpArrayFields<'ctx>;
fn struct_name(&self) -> &'static str {
"NDArray"
}
fn build_fields(
&self,
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
) -> Self::Fields {
NpArrayFields {
data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))),
itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())),
ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())),
shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))),
strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))),
}
}
}
pub struct IrrtStringFields<'ctx> {
pub buffer: GepGetter<AddressLens<IntLens<'ctx>>>,
pub capacity: GepGetter<IntLens<'ctx>>,
pub cursor: GepGetter<IntLens<'ctx>>,
}
#[derive(Debug, Clone, Copy)]
pub struct IrrtStringLens;
impl<'ctx> StructureOptic<'ctx> for IrrtStringLens {
type Fields = IrrtStringFields<'ctx>;
fn struct_name(&self) -> &'static str {
todo!()
}
fn build_fields(
&self,
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
) -> Self::Fields {
let llvm_i8 = builder.ctx.i8_type();
let llvm_i32 = builder.ctx.i32_type();
IrrtStringFields {
buffer: builder.add_field("buffer", AddressLens(IntLens(llvm_i8))),
capacity: builder.add_field("capacity", IntLens(llvm_i32)),
cursor: builder.add_field("cursor", IntLens(llvm_i32)),
}
}
}
pub struct ErrorContextFields {
pub message: GepGetter<IrrtStringLens>,
}
#[derive(Debug, Clone, Copy)]
pub struct ErrorContextLens;
impl<'ctx> StructureOptic<'ctx> for ErrorContextLens {
type Fields = ErrorContextFields;
fn struct_name(&self) -> &'static str {
"ErrorContext"
}
fn build_fields(
&self,
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
) -> Self::Fields {
ErrorContextFields { message: builder.add_field("message", IrrtStringLens) }
}
}

View File

@ -42,6 +42,7 @@ mod generator;
pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod optics;
pub mod stmt;
#[cfg(test)]

View File

@ -0,0 +1,342 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, IntType},
values::{AnyValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace,
};
use itertools::Itertools;
use super::CodeGenContext;
// TODO: Write a taxonomy
pub trait OpticValue<'ctx> {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
}
impl<'ctx, T: BasicValue<'ctx>> OpticValue<'ctx> for T {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
self.as_basic_value_enum()
}
}
// TODO: The interface is unintuitive
pub trait Optic<'ctx>: Clone {
type Value: OpticValue<'ctx>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Address<'ctx, Self> {
let ptr = ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap();
Address { addressee_optic: self.clone(), address: ptr }
}
}
pub trait Prism<'ctx>: Optic<'ctx> {
// TODO: Return error if `review` fails
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value;
}
pub trait MemoryGetter<'ctx>: Optic<'ctx> {
fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
name: &str,
) -> Self::Value;
}
pub trait MemorySetter<'ctx>: Optic<'ctx> {
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, value: &Self::Value);
}
pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {}
// NOTE: I wanted to make Int8Lens, Int16Lens, Int32Lens, with all
// having the trait IsIntLens, and implement `impl <S: IsIntLens> Optic<S> for T`,
// but that clashes with StructureOptic!!
#[derive(Debug, Clone, Copy)]
pub struct IntLens<'ctx>(pub IntType<'ctx>);
impl<'ctx> Optic<'ctx> for IntLens<'ctx> {
type Value = IntValue<'ctx>;
fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.as_basic_type_enum()
}
}
impl<'ctx> Prism<'ctx> for IntLens<'ctx> {
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value {
let int = value.as_any_value_enum().into_int_value();
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
int
}
}
impl<'ctx> MemoryGetter<'ctx> for IntLens<'ctx> {
fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
name: &str,
) -> Self::Value {
self.review(ctx.builder.build_load(pointer, name).unwrap())
}
}
impl<'ctx> MemorySetter<'ctx> for IntLens<'ctx> {
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, int: &Self::Value) {
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
ctx.builder.build_store(pointer, int.as_basic_value_enum()).unwrap();
}
}
#[derive(Debug, Clone)]
pub struct Address<'ctx, AddresseeOptic> {
pub addressee_optic: AddresseeOptic,
pub address: PointerValue<'ctx>,
}
impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> {
pub fn cast_to<S: Optic<'ctx>>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
new_optic: S,
) -> Address<'ctx, S> {
let to_ptr_type = new_optic.get_llvm_type(ctx.ctx).ptr_type(AddressSpace::default());
let casted_address =
ctx.builder.build_pointer_cast(self.address, to_ptr_type, "ptr_casted").unwrap();
Address { addressee_optic: new_optic, address: casted_address }
}
pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> {
self.cast_to(ctx, IntLens(ctx.ctx.i8_type()))
}
}
impl<'ctx, AddresseeOptic> OpticValue<'ctx> for Address<'ctx, AddresseeOptic> {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
self.address.as_basic_value_enum()
}
}
#[derive(Debug, Clone)]
pub struct AddressLens<AddresseeOptic>(pub AddresseeOptic);
impl<'ctx, AddresseeOptic: Optic<'ctx>> Optic<'ctx> for AddressLens<AddresseeOptic> {
type Value = Address<'ctx, AddresseeOptic>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}
impl<'ctx, AddresseeOptic: Optic<'ctx>> Prism<'ctx> for AddressLens<AddresseeOptic> {
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value {
Address {
addressee_optic: self.0.clone(),
address: value.as_any_value_enum().into_pointer_value(),
}
}
}
impl<'ctx, AddressesOptic: Optic<'ctx>> MemoryGetter<'ctx> for AddressLens<AddressesOptic> {
fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
name: &str,
) -> Self::Value {
self.review(ctx.builder.build_load(pointer, name).unwrap())
}
}
impl<'ctx, AddressesOptic: Optic<'ctx>> MemorySetter<'ctx> for AddressLens<AddressesOptic> {
fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
value: &Self::Value,
) {
ctx.builder.build_store(pointer, value.address).unwrap();
}
}
// To make [`Address`] convenient to use
impl<'ctx, AddresseeOptic: MemoryGetter<'ctx>> Address<'ctx, AddresseeOptic> {
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> AddresseeOptic::Value {
self.addressee_optic.get(ctx, self.address, name)
}
}
// To make [`Address`] convenient to use
impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> {
pub fn set(&self, ctx: &CodeGenContext<'ctx, '_>, value: &AddresseeOptic::Value) {
self.addressee_optic.set(ctx, self.address, value)
}
}
// ((Memory, Pointer) -> ElementOptic::Value*)
#[derive(Debug, Clone)]
pub struct GepGetter<ElementOptic> {
/// The LLVM GEP index
pub gep_index: u32, // TODO: I think I'm not supposed to *just* use i32 for GEP like that
/// Element (or field in the context of `struct`s) name. Used for cosmetics.
pub name: &'static str,
/// The lens to view the actual value after applying this [`FieldLens<T>`]
pub element_optic: ElementOptic,
}
impl<'ctx, ElementOptic: Optic<'ctx>> Optic<'ctx> for GepGetter<ElementOptic> {
type Value = Address<'ctx, ElementOptic>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.element_optic.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}
impl<'ctx, ElementOptic: Optic<'ctx>> MemoryGetter<'ctx> for GepGetter<ElementOptic> {
fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
name: &str,
) -> Self::Value {
let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
let element_ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
pointer,
&[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index as u64, false)],
name,
)
.unwrap()
};
Address { address: element_ptr, addressee_optic: self.element_optic.clone() }
}
}
// Only used by [`FieldBuilder`]
#[derive(Debug)]
struct FieldInfo<'ctx> {
gep_index: u32,
name: &'ctx str,
llvm_type: BasicTypeEnum<'ctx>,
}
#[derive(Debug)]
pub struct FieldBuilder<'ctx> {
pub ctx: &'ctx Context,
gep_index_counter: u32,
struct_name: &'ctx str,
fields: Vec<FieldInfo<'ctx>>,
}
impl<'ctx> FieldBuilder<'ctx> {
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() }
}
fn next_gep_index(&mut self) -> u32 {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
index
}
pub fn add_field<ElementOptic: Optic<'ctx>>(
&mut self,
name: &'static str,
element_optic: ElementOptic,
) -> GepGetter<ElementOptic> {
let gep_index = self.next_gep_index();
self.fields.push(FieldInfo {
gep_index,
name,
llvm_type: element_optic.get_llvm_type(self.ctx),
});
GepGetter { gep_index, name, element_optic }
}
}
pub trait StructureOptic<'ctx>: Clone {
// Fields of optics
type Fields;
// TODO: Make it an associated function instead?
fn struct_name(&self) -> &'static str;
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields;
fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields {
let mut builder = FieldBuilder::new(ctx, self.struct_name());
self.build_fields(&mut builder)
}
}
pub struct OpticalStructValue<'ctx, StructOptic> {
optic: StructOptic,
llvm: StructValue<'ctx>,
}
impl<'ctx, StructOptic> OpticValue<'ctx> for OpticalStructValue<'ctx, StructOptic> {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
self.llvm.as_basic_value_enum()
}
}
// TODO: check StructType
impl<'ctx, T: StructureOptic<'ctx>> Optic<'ctx> for T {
type Value = OpticalStructValue<'ctx, Self>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
let mut builder = FieldBuilder::new(ctx, self.struct_name());
self.build_fields(&mut builder); // Self::Fields is discarded
let field_types =
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
ctx.struct_type(&field_types, false).as_basic_type_enum()
}
}
impl<'ctx, T: StructureOptic<'ctx>> MemoryGetter<'ctx> for T {
fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
name: &str,
) -> Self::Value {
OpticalStructValue {
optic: self.clone(),
llvm: ctx.builder.build_load(pointer, name).unwrap().into_struct_value(),
}
}
}
impl<'ctx, T: StructureOptic<'ctx>> MemorySetter<'ctx> for T {
fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pointer: PointerValue<'ctx>,
value: &Self::Value,
) {
ctx.builder.build_store(pointer, value.llvm).unwrap();
}
}
impl<'ctx, AddresseeOptic: StructureOptic<'ctx>> Address<'ctx, AddresseeOptic> {
pub fn view<GetFieldGepFn, FieldElementOptic: Optic<'ctx>>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field_gep_fn: GetFieldGepFn,
) -> Address<'ctx, FieldElementOptic>
where
GetFieldGepFn: FnOnce(&AddresseeOptic::Fields) -> &GepGetter<FieldElementOptic>,
{
let fields = self.addressee_optic.get_fields(ctx.ctx);
let field = get_field_gep_fn(&fields);
field.get(ctx, self.address, field.name)
}
}