From 259481e8d0ae5e7eae36d1a41522b716e1f2cb24 Mon Sep 17 00:00:00 2001 From: lyken Date: Sat, 13 Jul 2024 19:08:57 +0800 Subject: [PATCH] core: optics.rs abstract inkwell --- nac3core/src/codegen/irrt/classes.rs | 94 ++++++++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/optics.rs | 342 +++++++++++++++++++++++++++ 3 files changed, 437 insertions(+) create mode 100644 nac3core/src/codegen/irrt/classes.rs create mode 100644 nac3core/src/codegen/optics.rs diff --git a/nac3core/src/codegen/irrt/classes.rs b/nac3core/src/codegen/irrt/classes.rs new file mode 100644 index 00000000..195f04b7 --- /dev/null +++ b/nac3core/src/codegen/irrt/classes.rs @@ -0,0 +1,94 @@ +use inkwell::types::{BasicTypeEnum, IntType}; + +use crate::codegen::optics::{AddressLens, GepGetter, IntLens, StructureOptic}; + +// use crate::codegen::structure::{ +// FieldLensBuilder, IntLens, LensWithFieldInfo, PointerLens, StructFieldLens, +// }; + +pub struct NpArrayFields<'ctx> { + pub data: GepGetter>>, + pub itemsize: GepGetter>, + pub ndims: GepGetter>, + pub shape: GepGetter>>, + pub strides: GepGetter>>, +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayLens<'ctx> { + pub size_type: IntType<'ctx>, + pub elem_type: BasicTypeEnum<'ctx>, +} + +impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> { + type Fields = NpArrayFields<'ctx>; + + fn struct_name(&self) -> &'static str { + "NDArray" + } + + fn build_fields( + &self, + builder: &mut crate::codegen::optics::FieldBuilder<'ctx>, + ) -> Self::Fields { + NpArrayFields { + data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))), + itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())), + ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())), + shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))), + strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))), + } + } +} + +pub struct IrrtStringFields<'ctx> { + pub buffer: GepGetter>>, + pub capacity: GepGetter>, + pub cursor: GepGetter>, +} + +#[derive(Debug, Clone, Copy)] +pub struct IrrtStringLens; + +impl<'ctx> StructureOptic<'ctx> for IrrtStringLens { + type Fields = IrrtStringFields<'ctx>; + + fn struct_name(&self) -> &'static str { + todo!() + } + + fn build_fields( + &self, + builder: &mut crate::codegen::optics::FieldBuilder<'ctx>, + ) -> Self::Fields { + let llvm_i8 = builder.ctx.i8_type(); + let llvm_i32 = builder.ctx.i32_type(); + IrrtStringFields { + buffer: builder.add_field("buffer", AddressLens(IntLens(llvm_i8))), + capacity: builder.add_field("capacity", IntLens(llvm_i32)), + cursor: builder.add_field("cursor", IntLens(llvm_i32)), + } + } +} + +pub struct ErrorContextFields { + pub message: GepGetter, +} + +#[derive(Debug, Clone, Copy)] +pub struct ErrorContextLens; + +impl<'ctx> StructureOptic<'ctx> for ErrorContextLens { + type Fields = ErrorContextFields; + + fn struct_name(&self) -> &'static str { + "ErrorContext" + } + + fn build_fields( + &self, + builder: &mut crate::codegen::optics::FieldBuilder<'ctx>, + ) -> Self::Fields { + ErrorContextFields { message: builder.add_field("message", IrrtStringLens) } + } +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 17952369..11c4d4ef 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -42,6 +42,7 @@ mod generator; pub mod irrt; pub mod llvm_intrinsics; pub mod numpy; +pub mod optics; pub mod stmt; #[cfg(test)] diff --git a/nac3core/src/codegen/optics.rs b/nac3core/src/codegen/optics.rs new file mode 100644 index 00000000..0465ba9d --- /dev/null +++ b/nac3core/src/codegen/optics.rs @@ -0,0 +1,342 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType}, + values::{AnyValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use super::CodeGenContext; + +// TODO: Write a taxonomy + +pub trait OpticValue<'ctx> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx>; +} + +impl<'ctx, T: BasicValue<'ctx>> OpticValue<'ctx> for T { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.as_basic_value_enum() + } +} + +// TODO: The interface is unintuitive +pub trait Optic<'ctx>: Clone { + type Value: OpticValue<'ctx>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; + + fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Address<'ctx, Self> { + let ptr = ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap(); + Address { addressee_optic: self.clone(), address: ptr } + } +} + +pub trait Prism<'ctx>: Optic<'ctx> { + // TODO: Return error if `review` fails + fn review>(&self, value: V) -> Self::Value; +} + +pub trait MemoryGetter<'ctx>: Optic<'ctx> { + fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + name: &str, + ) -> Self::Value; +} + +pub trait MemorySetter<'ctx>: Optic<'ctx> { + fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, value: &Self::Value); +} + +pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {} + +// NOTE: I wanted to make Int8Lens, Int16Lens, Int32Lens, with all +// having the trait IsIntLens, and implement `impl Optic for T`, +// but that clashes with StructureOptic!! +#[derive(Debug, Clone, Copy)] +pub struct IntLens<'ctx>(pub IntType<'ctx>); + +impl<'ctx> Optic<'ctx> for IntLens<'ctx> { + type Value = IntValue<'ctx>; + + fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() + } +} + +impl<'ctx> Prism<'ctx> for IntLens<'ctx> { + fn review>(&self, value: V) -> Self::Value { + let int = value.as_any_value_enum().into_int_value(); + debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width()); + int + } +} + +impl<'ctx> MemoryGetter<'ctx> for IntLens<'ctx> { + fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + name: &str, + ) -> Self::Value { + self.review(ctx.builder.build_load(pointer, name).unwrap()) + } +} + +impl<'ctx> MemorySetter<'ctx> for IntLens<'ctx> { + fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, int: &Self::Value) { + debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width()); + ctx.builder.build_store(pointer, int.as_basic_value_enum()).unwrap(); + } +} + +#[derive(Debug, Clone)] +pub struct Address<'ctx, AddresseeOptic> { + pub addressee_optic: AddresseeOptic, + pub address: PointerValue<'ctx>, +} + +impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> { + pub fn cast_to>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + new_optic: S, + ) -> Address<'ctx, S> { + let to_ptr_type = new_optic.get_llvm_type(ctx.ctx).ptr_type(AddressSpace::default()); + let casted_address = + ctx.builder.build_pointer_cast(self.address, to_ptr_type, "ptr_casted").unwrap(); + Address { addressee_optic: new_optic, address: casted_address } + } + + pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> { + self.cast_to(ctx, IntLens(ctx.ctx.i8_type())) + } +} + +impl<'ctx, AddresseeOptic> OpticValue<'ctx> for Address<'ctx, AddresseeOptic> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.address.as_basic_value_enum() + } +} + +#[derive(Debug, Clone)] +pub struct AddressLens(pub AddresseeOptic); + +impl<'ctx, AddresseeOptic: Optic<'ctx>> Optic<'ctx> for AddressLens { + type Value = Address<'ctx, AddresseeOptic>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() + } +} + +impl<'ctx, AddresseeOptic: Optic<'ctx>> Prism<'ctx> for AddressLens { + fn review>(&self, value: V) -> Self::Value { + Address { + addressee_optic: self.0.clone(), + address: value.as_any_value_enum().into_pointer_value(), + } + } +} + +impl<'ctx, AddressesOptic: Optic<'ctx>> MemoryGetter<'ctx> for AddressLens { + fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + name: &str, + ) -> Self::Value { + self.review(ctx.builder.build_load(pointer, name).unwrap()) + } +} + +impl<'ctx, AddressesOptic: Optic<'ctx>> MemorySetter<'ctx> for AddressLens { + fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + value: &Self::Value, + ) { + ctx.builder.build_store(pointer, value.address).unwrap(); + } +} + +// To make [`Address`] convenient to use +impl<'ctx, AddresseeOptic: MemoryGetter<'ctx>> Address<'ctx, AddresseeOptic> { + pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> AddresseeOptic::Value { + self.addressee_optic.get(ctx, self.address, name) + } +} + +// To make [`Address`] convenient to use +impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> { + pub fn set(&self, ctx: &CodeGenContext<'ctx, '_>, value: &AddresseeOptic::Value) { + self.addressee_optic.set(ctx, self.address, value) + } +} + +// ((Memory, Pointer) -> ElementOptic::Value*) +#[derive(Debug, Clone)] +pub struct GepGetter { + /// The LLVM GEP index + pub gep_index: u32, // TODO: I think I'm not supposed to *just* use i32 for GEP like that + /// Element (or field in the context of `struct`s) name. Used for cosmetics. + pub name: &'static str, + /// The lens to view the actual value after applying this [`FieldLens`] + pub element_optic: ElementOptic, +} + +impl<'ctx, ElementOptic: Optic<'ctx>> Optic<'ctx> for GepGetter { + type Value = Address<'ctx, ElementOptic>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.element_optic.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() + } +} + +impl<'ctx, ElementOptic: Optic<'ctx>> MemoryGetter<'ctx> for GepGetter { + fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + name: &str, + ) -> Self::Value { + let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that + let element_ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + pointer, + &[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index as u64, false)], + name, + ) + .unwrap() + }; + Address { address: element_ptr, addressee_optic: self.element_optic.clone() } + } +} + +// Only used by [`FieldBuilder`] +#[derive(Debug)] +struct FieldInfo<'ctx> { + gep_index: u32, + name: &'ctx str, + llvm_type: BasicTypeEnum<'ctx>, +} + +#[derive(Debug)] +pub struct FieldBuilder<'ctx> { + pub ctx: &'ctx Context, + gep_index_counter: u32, + struct_name: &'ctx str, + fields: Vec>, +} + +impl<'ctx> FieldBuilder<'ctx> { + pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self { + FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() } + } + + fn next_gep_index(&mut self) -> u32 { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + index + } + + pub fn add_field>( + &mut self, + name: &'static str, + element_optic: ElementOptic, + ) -> GepGetter { + let gep_index = self.next_gep_index(); + + self.fields.push(FieldInfo { + gep_index, + name, + llvm_type: element_optic.get_llvm_type(self.ctx), + }); + + GepGetter { gep_index, name, element_optic } + } +} + +pub trait StructureOptic<'ctx>: Clone { + // Fields of optics + type Fields; + + // TODO: Make it an associated function instead? + fn struct_name(&self) -> &'static str; + + fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields; + + fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder) + } +} + +pub struct OpticalStructValue<'ctx, StructOptic> { + optic: StructOptic, + llvm: StructValue<'ctx>, +} + +impl<'ctx, StructOptic> OpticValue<'ctx> for OpticalStructValue<'ctx, StructOptic> { + fn get_llvm_value(&self) -> BasicValueEnum<'ctx> { + self.llvm.as_basic_value_enum() + } +} + +// TODO: check StructType +impl<'ctx, T: StructureOptic<'ctx>> Optic<'ctx> for T { + type Value = OpticalStructValue<'ctx, Self>; + + fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + let mut builder = FieldBuilder::new(ctx, self.struct_name()); + self.build_fields(&mut builder); // Self::Fields is discarded + + let field_types = + builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec(); + ctx.struct_type(&field_types, false).as_basic_type_enum() + } +} + +impl<'ctx, T: StructureOptic<'ctx>> MemoryGetter<'ctx> for T { + fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + name: &str, + ) -> Self::Value { + OpticalStructValue { + optic: self.clone(), + llvm: ctx.builder.build_load(pointer, name).unwrap().into_struct_value(), + } + } +} + +impl<'ctx, T: StructureOptic<'ctx>> MemorySetter<'ctx> for T { + fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pointer: PointerValue<'ctx>, + value: &Self::Value, + ) { + ctx.builder.build_store(pointer, value.llvm).unwrap(); + } +} + +impl<'ctx, AddresseeOptic: StructureOptic<'ctx>> Address<'ctx, AddresseeOptic> { + pub fn view>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field_gep_fn: GetFieldGepFn, + ) -> Address<'ctx, FieldElementOptic> + where + GetFieldGepFn: FnOnce(&AddresseeOptic::Fields) -> &GepGetter, + { + let fields = self.addressee_optic.get_fields(ctx.ctx); + let field = get_field_gep_fn(&fields); + field.get(ctx, self.address, field.name) + } +}