1
0
forked from M-Labs/nac3

core: extended base type enum with opaque pointer

This commit is contained in:
abdul124 2024-06-28 17:29:01 +08:00
parent 9808923258
commit d30daf9ded
3 changed files with 143 additions and 18 deletions

View File

@ -159,9 +159,11 @@
# development tools
cargo-insta
clippy
pre-commit
rustfmt
rust-analyzer
];
# https://nixos.wiki/wiki/Rust#Shell.nix_example
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -0,0 +1,92 @@
use inkwell::llvm_sys::prelude::LLVMTypeRef;
use inkwell::types::{AsTypeRef, BasicTypeEnum, PointerType};
#[derive(Debug)]
pub struct OpaquePointerType<'ctx> {
pub ptr_ty: PointerType<'ctx>,
pub inner_ty: Box<Option<ExtendedTypeEnum<'ctx>>>,
}
#[derive(Debug)]
pub enum ExtendedTypeEnum<'ctx> {
BasicEnum(BasicTypeEnum<'ctx>),
OpaquePointer(OpaquePointerType<'ctx>),
}
unsafe impl AsTypeRef for ExtendedTypeEnum<'_> {
fn as_type_ref(&self) -> LLVMTypeRef {
match *self {
ExtendedTypeEnum::OpaquePointer(_) => panic!("Opaque Pointer Reference is not allowed"),
ExtendedTypeEnum::BasicEnum(t) => t.as_type_ref(),
}
}
}
impl ExtendedTypeEnum<'_> {
pub fn get_type(&self) -> BasicTypeEnum<'_> {
match self {
ExtendedTypeEnum::BasicEnum(t) => t.clone(),
ExtendedTypeEnum::OpaquePointer(t) => t.ptr_ty.clone().into(),
}
}
}
impl<'ctx> From<OpaquePointerType<'ctx>> for ExtendedTypeEnum<'ctx> {
fn from(value: OpaquePointerType) -> ExtendedTypeEnum {
ExtendedTypeEnum::OpaquePointer(value)
}
}
impl<'ctx> From<BasicTypeEnum<'ctx>> for ExtendedTypeEnum<'ctx> {
fn from(value: BasicTypeEnum) -> ExtendedTypeEnum {
ExtendedTypeEnum::BasicEnum(value)
}
}
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for OpaquePointerType<'ctx> {
type Error = ();
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
match value {
ExtendedTypeEnum::OpaquePointer(ty) => Ok(ty),
_ => Err(()),
}
}
}
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for BasicTypeEnum<'ctx> {
type Error = ();
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
match value {
ExtendedTypeEnum::BasicEnum(ty) => Ok(ty),
_ => Err(()),
}
}
}
impl<'ctx> ExtendedTypeEnum<'ctx> {
pub fn into_basic_type(self) -> BasicTypeEnum<'ctx> {
if let ExtendedTypeEnum::BasicEnum(t) = self {
t
} else {
panic!("Found {:?} but expected the ArrayType variant", self);
}
}
pub fn into_opaque_pointer(self) -> OpaquePointerType<'ctx> {
if let ExtendedTypeEnum::OpaquePointer(t) = self {
t
} else {
panic!("Found {:?} but expected the ArrayType variant", self);
}
}
pub fn is_basic_enum(self) -> bool {
matches!(self, ExtendedTypeEnum::BasicEnum(_))
}
pub fn is_opaque_pointer(self) -> bool {
matches!(self, ExtendedTypeEnum::OpaquePointer(_))
}
}

View File

@ -8,6 +8,8 @@ use crate::{
},
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use enums::OpaquePointerType;
use crate::codegen::enums::ExtendedTypeEnum;
use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
@ -26,7 +28,7 @@ use inkwell::{
use itertools::Itertools;
use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::{borrow::{Borrow, BorrowMut}, collections::{HashMap, HashSet}};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
@ -43,10 +45,12 @@ pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod stmt;
pub mod enums;
#[cfg(test)]
mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
@ -418,6 +422,7 @@ pub struct CodeGenTask {
pub id: usize,
}
/// Retrieves the [LLVM type][BasicTypeEnum] corresponding to the [Type].
///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
@ -431,11 +436,13 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type,
) -> BasicTypeEnum<'ctx> {
) -> ExtendedTypeEnum<'ctx> {
use TypeEnum::*;
// we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer.
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
if let Some(ty) = type_cache.get(&unifier.get_representative(ty)).copied(){
ExtendedTypeEnum::BasicEnum(ty)
} else {
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TObj { obj_id, fields, .. } => {
@ -443,7 +450,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
if PrimDef::contains_id(*obj_id) {
return match &*unifier.get_ty_immutable(ty) {
TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
get_llvm_type(
let ty = get_llvm_type(
ctx,
module,
generator,
@ -451,9 +458,18 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level,
type_cache,
*params.iter().next().unwrap().1,
)
.ptr_type(AddressSpace::default())
.into()
);
let inner_ty = match ty.into() {
ExtendedTypeEnum::BasicEnum(t) => Some(ExtendedTypeEnum::BasicEnum(t)),
ExtendedTypeEnum::OpaquePointer(t) => *t.inner_ty,
};
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: ty.get_type().ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(inner_ty),
})
// ty.ptr_type(AddressSpace::default()).into()
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
@ -462,7 +478,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
// Assuming it is BasicType for now
ExtendedTypeEnum::BasicEnum(NDArrayType::new(generator, ctx, element_type.get_type().clone().to_owned()).as_base_type().into())
// NDArrayType::new(generator, ctx, element_type).as_base_type().into()
}
_ => unreachable!(
@ -480,7 +499,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let name = unifier.stringify(ty);
let ty = if let Some(t) = module.get_struct_type(&name) {
t.ptr_type(AddressSpace::default()).into()
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: t.ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(t.into()))),
})
// t.ptr_type(AddressSpace::default()).into()
} else {
let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert(
@ -498,11 +521,15 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level,
type_cache,
fields[&f.0].0,
)
).get_type()
})
.collect_vec();
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: struct_type.ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(struct_type.into())))
})
// struct_type.ptr_type(AddressSpace::default()).into()
};
return ty;
}
@ -511,23 +538,27 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let fields = ty
.iter()
.map(|ty| {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty).get_type()
})
.collect_vec();
ctx.struct_type(&fields, false).into()
ExtendedTypeEnum::BasicEnum(ctx.struct_type(&fields, false).into())
// ctx.struct_type(&fields, false).into()
}
TList { ty } => {
let element_type =
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
ListType::new(generator, ctx, element_type).as_base_type().into()
// Assuming same as numpy
ExtendedTypeEnum::BasicEnum(ListType::new(generator, ctx, element_type.get_type()).as_base_type().into())
// ListType::new(generator, ctx, element_type).as_base_type().into()
}
TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()),
};
type_cache.insert(unifier.get_representative(ty), result);
type_cache.insert(unifier.get_representative(ty), result.get_type());
// type_cache.insert(unifier.get_representative(ty), result);
result
})
}
}
/// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`].