forked from M-Labs/nac3
1
0
Fork 0

Compare commits

...

1 Commits

Author SHA1 Message Date
abdul124 d30daf9ded core: extended base type enum with opaque pointer 2024-06-28 17:29:01 +08:00
3 changed files with 143 additions and 18 deletions

View File

@ -159,9 +159,11 @@
# development tools # development tools
cargo-insta cargo-insta
clippy clippy
pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
# https://nixos.wiki/wiki/Rust#Shell.nix_example
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";
@ -185,4 +187,4 @@
extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc="; extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
extra-substituters = "https://nixbld.m-labs.hk"; extra-substituters = "https://nixbld.m-labs.hk";
}; };
} }

View File

@ -0,0 +1,92 @@
use inkwell::llvm_sys::prelude::LLVMTypeRef;
use inkwell::types::{AsTypeRef, BasicTypeEnum, PointerType};
#[derive(Debug)]
pub struct OpaquePointerType<'ctx> {
pub ptr_ty: PointerType<'ctx>,
pub inner_ty: Box<Option<ExtendedTypeEnum<'ctx>>>,
}
#[derive(Debug)]
pub enum ExtendedTypeEnum<'ctx> {
BasicEnum(BasicTypeEnum<'ctx>),
OpaquePointer(OpaquePointerType<'ctx>),
}
unsafe impl AsTypeRef for ExtendedTypeEnum<'_> {
fn as_type_ref(&self) -> LLVMTypeRef {
match *self {
ExtendedTypeEnum::OpaquePointer(_) => panic!("Opaque Pointer Reference is not allowed"),
ExtendedTypeEnum::BasicEnum(t) => t.as_type_ref(),
}
}
}
impl ExtendedTypeEnum<'_> {
pub fn get_type(&self) -> BasicTypeEnum<'_> {
match self {
ExtendedTypeEnum::BasicEnum(t) => t.clone(),
ExtendedTypeEnum::OpaquePointer(t) => t.ptr_ty.clone().into(),
}
}
}
impl<'ctx> From<OpaquePointerType<'ctx>> for ExtendedTypeEnum<'ctx> {
fn from(value: OpaquePointerType) -> ExtendedTypeEnum {
ExtendedTypeEnum::OpaquePointer(value)
}
}
impl<'ctx> From<BasicTypeEnum<'ctx>> for ExtendedTypeEnum<'ctx> {
fn from(value: BasicTypeEnum) -> ExtendedTypeEnum {
ExtendedTypeEnum::BasicEnum(value)
}
}
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for OpaquePointerType<'ctx> {
type Error = ();
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
match value {
ExtendedTypeEnum::OpaquePointer(ty) => Ok(ty),
_ => Err(()),
}
}
}
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for BasicTypeEnum<'ctx> {
type Error = ();
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
match value {
ExtendedTypeEnum::BasicEnum(ty) => Ok(ty),
_ => Err(()),
}
}
}
impl<'ctx> ExtendedTypeEnum<'ctx> {
pub fn into_basic_type(self) -> BasicTypeEnum<'ctx> {
if let ExtendedTypeEnum::BasicEnum(t) = self {
t
} else {
panic!("Found {:?} but expected the ArrayType variant", self);
}
}
pub fn into_opaque_pointer(self) -> OpaquePointerType<'ctx> {
if let ExtendedTypeEnum::OpaquePointer(t) = self {
t
} else {
panic!("Found {:?} but expected the ArrayType variant", self);
}
}
pub fn is_basic_enum(self) -> bool {
matches!(self, ExtendedTypeEnum::BasicEnum(_))
}
pub fn is_opaque_pointer(self) -> bool {
matches!(self, ExtendedTypeEnum::OpaquePointer(_))
}
}

View File

@ -8,6 +8,8 @@ use crate::{
}, },
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use enums::OpaquePointerType;
use crate::codegen::enums::ExtendedTypeEnum;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
@ -26,7 +28,7 @@ use inkwell::{
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Location, Stmt, StrRef}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet}; use std::{borrow::{Borrow, BorrowMut}, collections::{HashMap, HashSet}};
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -43,10 +45,12 @@ pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod numpy; pub mod numpy;
pub mod stmt; pub mod stmt;
pub mod enums;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator}; pub use generator::{CodeGenerator, DefaultCodeGenerator};
@ -418,6 +422,7 @@ pub struct CodeGenTask {
pub id: usize, pub id: usize,
} }
/// Retrieves the [LLVM type][BasicTypeEnum] corresponding to the [Type]. /// Retrieves the [LLVM type][BasicTypeEnum] corresponding to the [Type].
/// ///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable /// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
@ -431,11 +436,13 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type, ty: Type,
) -> BasicTypeEnum<'ctx> { ) -> ExtendedTypeEnum<'ctx> {
use TypeEnum::*; use TypeEnum::*;
// we assume the type cache should already contain primitive types, // we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer. // and they should be passed by value instead of passing as pointer.
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { if let Some(ty) = type_cache.get(&unifier.get_representative(ty)).copied(){
ExtendedTypeEnum::BasicEnum(ty)
} else {
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
@ -443,7 +450,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
if PrimDef::contains_id(*obj_id) { if PrimDef::contains_id(*obj_id) {
return match &*unifier.get_ty_immutable(ty) { return match &*unifier.get_ty_immutable(ty) {
TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => { TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
get_llvm_type( let ty = get_llvm_type(
ctx, ctx,
module, module,
generator, generator,
@ -451,9 +458,18 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level, top_level,
type_cache, type_cache,
*params.iter().next().unwrap().1, *params.iter().next().unwrap().1,
) );
.ptr_type(AddressSpace::default()) let inner_ty = match ty.into() {
.into() ExtendedTypeEnum::BasicEnum(t) => Some(ExtendedTypeEnum::BasicEnum(t)),
ExtendedTypeEnum::OpaquePointer(t) => *t.inner_ty,
};
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: ty.get_type().ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(inner_ty),
})
// ty.ptr_type(AddressSpace::default()).into()
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
@ -461,8 +477,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
// Assuming it is BasicType for now
ExtendedTypeEnum::BasicEnum(NDArrayType::new(generator, ctx, element_type.get_type().clone().to_owned()).as_base_type().into())
NDArrayType::new(generator, ctx, element_type).as_base_type().into() // NDArrayType::new(generator, ctx, element_type).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -480,7 +499,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let name = unifier.stringify(ty); let name = unifier.stringify(ty);
let ty = if let Some(t) = module.get_struct_type(&name) { let ty = if let Some(t) = module.get_struct_type(&name) {
t.ptr_type(AddressSpace::default()).into() ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: t.ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(t.into()))),
})
// t.ptr_type(AddressSpace::default()).into()
} else { } else {
let struct_type = ctx.opaque_struct_type(&name); let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert( type_cache.insert(
@ -498,11 +521,15 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
top_level, top_level,
type_cache, type_cache,
fields[&f.0].0, fields[&f.0].0,
) ).get_type()
}) })
.collect_vec(); .collect_vec();
struct_type.set_body(&fields, false); struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into() ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
ptr_ty: struct_type.ptr_type(AddressSpace::default()).into(),
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(struct_type.into())))
})
// struct_type.ptr_type(AddressSpace::default()).into()
}; };
return ty; return ty;
} }
@ -511,23 +538,27 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty).get_type()
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ExtendedTypeEnum::BasicEnum(ctx.struct_type(&fields, false).into())
// ctx.struct_type(&fields, false).into()
} }
TList { ty } => { TList { ty } => {
let element_type = let element_type =
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty); get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
ListType::new(generator, ctx, element_type).as_base_type().into() // Assuming same as numpy
ExtendedTypeEnum::BasicEnum(ListType::new(generator, ctx, element_type.get_type()).as_base_type().into())
// ListType::new(generator, ctx, element_type).as_base_type().into()
} }
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()), _ => unreachable!("{}", ty_enum.get_type_name()),
}; };
type_cache.insert(unifier.get_representative(ty), result); type_cache.insert(unifier.get_representative(ty), result.get_type());
// type_cache.insert(unifier.get_representative(ty), result);
result result
}) }
} }
/// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`]. /// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`].