forked from M-Labs/nac3
core: optics.rs abstract inkwell
This commit is contained in:
parent
5faac4b9d4
commit
259481e8d0
94
nac3core/src/codegen/irrt/classes.rs
Normal file
94
nac3core/src/codegen/irrt/classes.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use inkwell::types::{BasicTypeEnum, IntType};
|
||||
|
||||
use crate::codegen::optics::{AddressLens, GepGetter, IntLens, StructureOptic};
|
||||
|
||||
// use crate::codegen::structure::{
|
||||
// FieldLensBuilder, IntLens, LensWithFieldInfo, PointerLens, StructFieldLens,
|
||||
// };
|
||||
|
||||
pub struct NpArrayFields<'ctx> {
|
||||
pub data: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||
pub itemsize: GepGetter<IntLens<'ctx>>,
|
||||
pub ndims: GepGetter<IntLens<'ctx>>,
|
||||
pub shape: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||
pub strides: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NpArrayLens<'ctx> {
|
||||
pub size_type: IntType<'ctx>,
|
||||
pub elem_type: BasicTypeEnum<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> StructureOptic<'ctx> for NpArrayLens<'ctx> {
|
||||
type Fields = NpArrayFields<'ctx>;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
"NDArray"
|
||||
}
|
||||
|
||||
fn build_fields(
|
||||
&self,
|
||||
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
|
||||
) -> Self::Fields {
|
||||
NpArrayFields {
|
||||
data: builder.add_field("data", AddressLens(IntLens(builder.ctx.i8_type()))),
|
||||
itemsize: builder.add_field("itemsize", IntLens(builder.ctx.i8_type())),
|
||||
ndims: builder.add_field("ndims", IntLens(builder.ctx.i8_type())),
|
||||
shape: builder.add_field("shape", AddressLens(IntLens(self.size_type))),
|
||||
strides: builder.add_field("strides", AddressLens(IntLens(self.size_type))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IrrtStringFields<'ctx> {
|
||||
pub buffer: GepGetter<AddressLens<IntLens<'ctx>>>,
|
||||
pub capacity: GepGetter<IntLens<'ctx>>,
|
||||
pub cursor: GepGetter<IntLens<'ctx>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct IrrtStringLens;
|
||||
|
||||
impl<'ctx> StructureOptic<'ctx> for IrrtStringLens {
|
||||
type Fields = IrrtStringFields<'ctx>;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn build_fields(
|
||||
&self,
|
||||
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
|
||||
) -> Self::Fields {
|
||||
let llvm_i8 = builder.ctx.i8_type();
|
||||
let llvm_i32 = builder.ctx.i32_type();
|
||||
IrrtStringFields {
|
||||
buffer: builder.add_field("buffer", AddressLens(IntLens(llvm_i8))),
|
||||
capacity: builder.add_field("capacity", IntLens(llvm_i32)),
|
||||
cursor: builder.add_field("cursor", IntLens(llvm_i32)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ErrorContextFields {
|
||||
pub message: GepGetter<IrrtStringLens>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ErrorContextLens;
|
||||
|
||||
impl<'ctx> StructureOptic<'ctx> for ErrorContextLens {
|
||||
type Fields = ErrorContextFields;
|
||||
|
||||
fn struct_name(&self) -> &'static str {
|
||||
"ErrorContext"
|
||||
}
|
||||
|
||||
fn build_fields(
|
||||
&self,
|
||||
builder: &mut crate::codegen::optics::FieldBuilder<'ctx>,
|
||||
) -> Self::Fields {
|
||||
ErrorContextFields { message: builder.add_field("message", IrrtStringLens) }
|
||||
}
|
||||
}
|
@ -42,6 +42,7 @@ mod generator;
|
||||
pub mod irrt;
|
||||
pub mod llvm_intrinsics;
|
||||
pub mod numpy;
|
||||
pub mod optics;
|
||||
pub mod stmt;
|
||||
|
||||
#[cfg(test)]
|
||||
|
342
nac3core/src/codegen/optics.rs
Normal file
342
nac3core/src/codegen/optics.rs
Normal file
@ -0,0 +1,342 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicType, BasicTypeEnum, IntType},
|
||||
values::{AnyValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::CodeGenContext;
|
||||
|
||||
// TODO: Write a taxonomy
|
||||
|
||||
pub trait OpticValue<'ctx> {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||
}
|
||||
|
||||
impl<'ctx, T: BasicValue<'ctx>> OpticValue<'ctx> for T {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||
self.as_basic_value_enum()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: The interface is unintuitive
|
||||
pub trait Optic<'ctx>: Clone {
|
||||
type Value: OpticValue<'ctx>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||
|
||||
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Address<'ctx, Self> {
|
||||
let ptr = ctx.builder.build_alloca(self.get_llvm_type(ctx.ctx), name).unwrap();
|
||||
Address { addressee_optic: self.clone(), address: ptr }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Prism<'ctx>: Optic<'ctx> {
|
||||
// TODO: Return error if `review` fails
|
||||
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value;
|
||||
}
|
||||
|
||||
pub trait MemoryGetter<'ctx>: Optic<'ctx> {
|
||||
fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
name: &str,
|
||||
) -> Self::Value;
|
||||
}
|
||||
|
||||
pub trait MemorySetter<'ctx>: Optic<'ctx> {
|
||||
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, value: &Self::Value);
|
||||
}
|
||||
|
||||
pub trait SizedIntLens<'ctx>: Optic<'ctx, Value = IntValue<'ctx>> {}
|
||||
|
||||
// NOTE: I wanted to make Int8Lens, Int16Lens, Int32Lens, with all
|
||||
// having the trait IsIntLens, and implement `impl <S: IsIntLens> Optic<S> for T`,
|
||||
// but that clashes with StructureOptic!!
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct IntLens<'ctx>(pub IntType<'ctx>);
|
||||
|
||||
impl<'ctx> Optic<'ctx> for IntLens<'ctx> {
|
||||
type Value = IntValue<'ctx>;
|
||||
|
||||
fn get_llvm_type(&self, _ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
self.0.as_basic_type_enum()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> Prism<'ctx> for IntLens<'ctx> {
|
||||
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value {
|
||||
let int = value.as_any_value_enum().into_int_value();
|
||||
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||
int
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> MemoryGetter<'ctx> for IntLens<'ctx> {
|
||||
fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
name: &str,
|
||||
) -> Self::Value {
|
||||
self.review(ctx.builder.build_load(pointer, name).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> MemorySetter<'ctx> for IntLens<'ctx> {
|
||||
fn set(&self, ctx: &CodeGenContext<'ctx, '_>, pointer: PointerValue<'ctx>, int: &Self::Value) {
|
||||
debug_assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||
ctx.builder.build_store(pointer, int.as_basic_value_enum()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Address<'ctx, AddresseeOptic> {
|
||||
pub addressee_optic: AddresseeOptic,
|
||||
pub address: PointerValue<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx, AddresseeOptic> Address<'ctx, AddresseeOptic> {
|
||||
pub fn cast_to<S: Optic<'ctx>>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
new_optic: S,
|
||||
) -> Address<'ctx, S> {
|
||||
let to_ptr_type = new_optic.get_llvm_type(ctx.ctx).ptr_type(AddressSpace::default());
|
||||
let casted_address =
|
||||
ctx.builder.build_pointer_cast(self.address, to_ptr_type, "ptr_casted").unwrap();
|
||||
Address { addressee_optic: new_optic, address: casted_address }
|
||||
}
|
||||
|
||||
pub fn cast_to_opaque(&self, ctx: &CodeGenContext<'ctx, '_>) -> Address<'ctx, IntLens<'ctx>> {
|
||||
self.cast_to(ctx, IntLens(ctx.ctx.i8_type()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, AddresseeOptic> OpticValue<'ctx> for Address<'ctx, AddresseeOptic> {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||
self.address.as_basic_value_enum()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AddressLens<AddresseeOptic>(pub AddresseeOptic);
|
||||
|
||||
impl<'ctx, AddresseeOptic: Optic<'ctx>> Optic<'ctx> for AddressLens<AddresseeOptic> {
|
||||
type Value = Address<'ctx, AddresseeOptic>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, AddresseeOptic: Optic<'ctx>> Prism<'ctx> for AddressLens<AddresseeOptic> {
|
||||
fn review<V: AnyValue<'ctx>>(&self, value: V) -> Self::Value {
|
||||
Address {
|
||||
addressee_optic: self.0.clone(),
|
||||
address: value.as_any_value_enum().into_pointer_value(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, AddressesOptic: Optic<'ctx>> MemoryGetter<'ctx> for AddressLens<AddressesOptic> {
|
||||
fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
name: &str,
|
||||
) -> Self::Value {
|
||||
self.review(ctx.builder.build_load(pointer, name).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, AddressesOptic: Optic<'ctx>> MemorySetter<'ctx> for AddressLens<AddressesOptic> {
|
||||
fn set(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
value: &Self::Value,
|
||||
) {
|
||||
ctx.builder.build_store(pointer, value.address).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// To make [`Address`] convenient to use
|
||||
impl<'ctx, AddresseeOptic: MemoryGetter<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> AddresseeOptic::Value {
|
||||
self.addressee_optic.get(ctx, self.address, name)
|
||||
}
|
||||
}
|
||||
|
||||
// To make [`Address`] convenient to use
|
||||
impl<'ctx, AddresseeOptic: MemorySetter<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||
pub fn set(&self, ctx: &CodeGenContext<'ctx, '_>, value: &AddresseeOptic::Value) {
|
||||
self.addressee_optic.set(ctx, self.address, value)
|
||||
}
|
||||
}
|
||||
|
||||
// ((Memory, Pointer) -> ElementOptic::Value*)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GepGetter<ElementOptic> {
|
||||
/// The LLVM GEP index
|
||||
pub gep_index: u32, // TODO: I think I'm not supposed to *just* use i32 for GEP like that
|
||||
/// Element (or field in the context of `struct`s) name. Used for cosmetics.
|
||||
pub name: &'static str,
|
||||
/// The lens to view the actual value after applying this [`FieldLens<T>`]
|
||||
pub element_optic: ElementOptic,
|
||||
}
|
||||
|
||||
impl<'ctx, ElementOptic: Optic<'ctx>> Optic<'ctx> for GepGetter<ElementOptic> {
|
||||
type Value = Address<'ctx, ElementOptic>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
self.element_optic.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, ElementOptic: Optic<'ctx>> MemoryGetter<'ctx> for GepGetter<ElementOptic> {
|
||||
fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
name: &str,
|
||||
) -> Self::Value {
|
||||
let llvm_i32 = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
|
||||
let element_ptr = unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
pointer,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(self.gep_index as u64, false)],
|
||||
name,
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
Address { address: element_ptr, addressee_optic: self.element_optic.clone() }
|
||||
}
|
||||
}
|
||||
|
||||
// Only used by [`FieldBuilder`]
|
||||
#[derive(Debug)]
|
||||
struct FieldInfo<'ctx> {
|
||||
gep_index: u32,
|
||||
name: &'ctx str,
|
||||
llvm_type: BasicTypeEnum<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FieldBuilder<'ctx> {
|
||||
pub ctx: &'ctx Context,
|
||||
gep_index_counter: u32,
|
||||
struct_name: &'ctx str,
|
||||
fields: Vec<FieldInfo<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> FieldBuilder<'ctx> {
|
||||
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
|
||||
FieldBuilder { ctx, gep_index_counter: 0, struct_name, fields: Vec::new() }
|
||||
}
|
||||
|
||||
fn next_gep_index(&mut self) -> u32 {
|
||||
let index = self.gep_index_counter;
|
||||
self.gep_index_counter += 1;
|
||||
index
|
||||
}
|
||||
|
||||
pub fn add_field<ElementOptic: Optic<'ctx>>(
|
||||
&mut self,
|
||||
name: &'static str,
|
||||
element_optic: ElementOptic,
|
||||
) -> GepGetter<ElementOptic> {
|
||||
let gep_index = self.next_gep_index();
|
||||
|
||||
self.fields.push(FieldInfo {
|
||||
gep_index,
|
||||
name,
|
||||
llvm_type: element_optic.get_llvm_type(self.ctx),
|
||||
});
|
||||
|
||||
GepGetter { gep_index, name, element_optic }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait StructureOptic<'ctx>: Clone {
|
||||
// Fields of optics
|
||||
type Fields;
|
||||
|
||||
// TODO: Make it an associated function instead?
|
||||
fn struct_name(&self) -> &'static str;
|
||||
|
||||
fn build_fields(&self, builder: &mut FieldBuilder<'ctx>) -> Self::Fields;
|
||||
|
||||
fn get_fields(&self, ctx: &'ctx Context) -> Self::Fields {
|
||||
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||
self.build_fields(&mut builder)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpticalStructValue<'ctx, StructOptic> {
|
||||
optic: StructOptic,
|
||||
llvm: StructValue<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx, StructOptic> OpticValue<'ctx> for OpticalStructValue<'ctx, StructOptic> {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx> {
|
||||
self.llvm.as_basic_value_enum()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check StructType
|
||||
impl<'ctx, T: StructureOptic<'ctx>> Optic<'ctx> for T {
|
||||
type Value = OpticalStructValue<'ctx, Self>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
let mut builder = FieldBuilder::new(ctx, self.struct_name());
|
||||
self.build_fields(&mut builder); // Self::Fields is discarded
|
||||
|
||||
let field_types =
|
||||
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
|
||||
ctx.struct_type(&field_types, false).as_basic_type_enum()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, T: StructureOptic<'ctx>> MemoryGetter<'ctx> for T {
|
||||
fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
name: &str,
|
||||
) -> Self::Value {
|
||||
OpticalStructValue {
|
||||
optic: self.clone(),
|
||||
llvm: ctx.builder.build_load(pointer, name).unwrap().into_struct_value(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, T: StructureOptic<'ctx>> MemorySetter<'ctx> for T {
|
||||
fn set(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pointer: PointerValue<'ctx>,
|
||||
value: &Self::Value,
|
||||
) {
|
||||
ctx.builder.build_store(pointer, value.llvm).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, AddresseeOptic: StructureOptic<'ctx>> Address<'ctx, AddresseeOptic> {
|
||||
pub fn view<GetFieldGepFn, FieldElementOptic: Optic<'ctx>>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
get_field_gep_fn: GetFieldGepFn,
|
||||
) -> Address<'ctx, FieldElementOptic>
|
||||
where
|
||||
GetFieldGepFn: FnOnce(&AddresseeOptic::Fields) -> &GepGetter<FieldElementOptic>,
|
||||
{
|
||||
let fields = self.addressee_optic.get_fields(ctx.ctx);
|
||||
let field = get_field_gep_fn(&fields);
|
||||
field.get(ctx, self.address, field.name)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user