forked from M-Labs/nac3
Compare commits
1 Commits
ndarray-st
...
issue-313
Author | SHA1 | Date | |
---|---|---|---|
d30daf9ded |
@ -159,9 +159,11 @@
|
||||
# development tools
|
||||
cargo-insta
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
# https://nixos.wiki/wiki/Rust#Shell.nix_example
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||
name = "nac3-dev-shell-msys2";
|
||||
@ -185,4 +187,4 @@
|
||||
extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
|
||||
extra-substituters = "https://nixbld.m-labs.hk";
|
||||
};
|
||||
}
|
||||
}
|
92
nac3core/src/codegen/enums.rs
Normal file
92
nac3core/src/codegen/enums.rs
Normal file
@ -0,0 +1,92 @@
|
||||
use inkwell::llvm_sys::prelude::LLVMTypeRef;
|
||||
use inkwell::types::{AsTypeRef, BasicTypeEnum, PointerType};
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpaquePointerType<'ctx> {
|
||||
pub ptr_ty: PointerType<'ctx>,
|
||||
pub inner_ty: Box<Option<ExtendedTypeEnum<'ctx>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ExtendedTypeEnum<'ctx> {
|
||||
BasicEnum(BasicTypeEnum<'ctx>),
|
||||
OpaquePointer(OpaquePointerType<'ctx>),
|
||||
}
|
||||
|
||||
unsafe impl AsTypeRef for ExtendedTypeEnum<'_> {
|
||||
fn as_type_ref(&self) -> LLVMTypeRef {
|
||||
match *self {
|
||||
ExtendedTypeEnum::OpaquePointer(_) => panic!("Opaque Pointer Reference is not allowed"),
|
||||
ExtendedTypeEnum::BasicEnum(t) => t.as_type_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtendedTypeEnum<'_> {
|
||||
pub fn get_type(&self) -> BasicTypeEnum<'_> {
|
||||
match self {
|
||||
ExtendedTypeEnum::BasicEnum(t) => t.clone(),
|
||||
ExtendedTypeEnum::OpaquePointer(t) => t.ptr_ty.clone().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<OpaquePointerType<'ctx>> for ExtendedTypeEnum<'ctx> {
|
||||
fn from(value: OpaquePointerType) -> ExtendedTypeEnum {
|
||||
ExtendedTypeEnum::OpaquePointer(value)
|
||||
}
|
||||
}
|
||||
impl<'ctx> From<BasicTypeEnum<'ctx>> for ExtendedTypeEnum<'ctx> {
|
||||
fn from(value: BasicTypeEnum) -> ExtendedTypeEnum {
|
||||
ExtendedTypeEnum::BasicEnum(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for OpaquePointerType<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ExtendedTypeEnum::OpaquePointer(ty) => Ok(ty),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'ctx> TryFrom<ExtendedTypeEnum<'ctx>> for BasicTypeEnum<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: ExtendedTypeEnum<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ExtendedTypeEnum::BasicEnum(ty) => Ok(ty),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<'ctx> ExtendedTypeEnum<'ctx> {
|
||||
pub fn into_basic_type(self) -> BasicTypeEnum<'ctx> {
|
||||
if let ExtendedTypeEnum::BasicEnum(t) = self {
|
||||
t
|
||||
} else {
|
||||
panic!("Found {:?} but expected the ArrayType variant", self);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_opaque_pointer(self) -> OpaquePointerType<'ctx> {
|
||||
if let ExtendedTypeEnum::OpaquePointer(t) = self {
|
||||
t
|
||||
} else {
|
||||
panic!("Found {:?} but expected the ArrayType variant", self);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_basic_enum(self) -> bool {
|
||||
matches!(self, ExtendedTypeEnum::BasicEnum(_))
|
||||
}
|
||||
|
||||
pub fn is_opaque_pointer(self) -> bool {
|
||||
matches!(self, ExtendedTypeEnum::OpaquePointer(_))
|
||||
}
|
||||
}
|
@ -8,6 +8,8 @@ use crate::{
|
||||
},
|
||||
};
|
||||
use crossbeam::channel::{unbounded, Receiver, Sender};
|
||||
use enums::OpaquePointerType;
|
||||
use crate::codegen::enums::ExtendedTypeEnum;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
basic_block::BasicBlock,
|
||||
@ -26,7 +28,7 @@ use inkwell::{
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::{Location, Stmt, StrRef};
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::{borrow::{Borrow, BorrowMut}, collections::{HashMap, HashSet}};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@ -43,10 +45,12 @@ pub mod irrt;
|
||||
pub mod llvm_intrinsics;
|
||||
pub mod numpy;
|
||||
pub mod stmt;
|
||||
pub mod enums;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
|
||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||
|
||||
@ -418,6 +422,7 @@ pub struct CodeGenTask {
|
||||
pub id: usize,
|
||||
}
|
||||
|
||||
|
||||
/// Retrieves the [LLVM type][BasicTypeEnum] corresponding to the [Type].
|
||||
///
|
||||
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
|
||||
@ -431,11 +436,13 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
top_level: &TopLevelContext,
|
||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
ty: Type,
|
||||
) -> BasicTypeEnum<'ctx> {
|
||||
) -> ExtendedTypeEnum<'ctx> {
|
||||
use TypeEnum::*;
|
||||
// we assume the type cache should already contain primitive types,
|
||||
// and they should be passed by value instead of passing as pointer.
|
||||
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
|
||||
if let Some(ty) = type_cache.get(&unifier.get_representative(ty)).copied(){
|
||||
ExtendedTypeEnum::BasicEnum(ty)
|
||||
} else {
|
||||
let ty_enum = unifier.get_ty(ty);
|
||||
let result = match &*ty_enum {
|
||||
TObj { obj_id, fields, .. } => {
|
||||
@ -443,7 +450,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
if PrimDef::contains_id(*obj_id) {
|
||||
return match &*unifier.get_ty_immutable(ty) {
|
||||
TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
|
||||
get_llvm_type(
|
||||
let ty = get_llvm_type(
|
||||
ctx,
|
||||
module,
|
||||
generator,
|
||||
@ -451,9 +458,18 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
top_level,
|
||||
type_cache,
|
||||
*params.iter().next().unwrap().1,
|
||||
)
|
||||
.ptr_type(AddressSpace::default())
|
||||
.into()
|
||||
);
|
||||
let inner_ty = match ty.into() {
|
||||
ExtendedTypeEnum::BasicEnum(t) => Some(ExtendedTypeEnum::BasicEnum(t)),
|
||||
ExtendedTypeEnum::OpaquePointer(t) => *t.inner_ty,
|
||||
};
|
||||
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
|
||||
ptr_ty: ty.get_type().ptr_type(AddressSpace::default()).into(),
|
||||
inner_ty: Box::new(inner_ty),
|
||||
})
|
||||
|
||||
// ty.ptr_type(AddressSpace::default()).into()
|
||||
|
||||
}
|
||||
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
@ -461,8 +477,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let element_type = get_llvm_type(
|
||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||
);
|
||||
|
||||
// Assuming it is BasicType for now
|
||||
ExtendedTypeEnum::BasicEnum(NDArrayType::new(generator, ctx, element_type.get_type().clone().to_owned()).as_base_type().into())
|
||||
|
||||
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||
// NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||
}
|
||||
|
||||
_ => unreachable!(
|
||||
@ -480,7 +499,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
let name = unifier.stringify(ty);
|
||||
let ty = if let Some(t) = module.get_struct_type(&name) {
|
||||
t.ptr_type(AddressSpace::default()).into()
|
||||
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
|
||||
ptr_ty: t.ptr_type(AddressSpace::default()).into(),
|
||||
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(t.into()))),
|
||||
})
|
||||
// t.ptr_type(AddressSpace::default()).into()
|
||||
} else {
|
||||
let struct_type = ctx.opaque_struct_type(&name);
|
||||
type_cache.insert(
|
||||
@ -498,11 +521,15 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
top_level,
|
||||
type_cache,
|
||||
fields[&f.0].0,
|
||||
)
|
||||
).get_type()
|
||||
})
|
||||
.collect_vec();
|
||||
struct_type.set_body(&fields, false);
|
||||
struct_type.ptr_type(AddressSpace::default()).into()
|
||||
ExtendedTypeEnum::OpaquePointer(OpaquePointerType{
|
||||
ptr_ty: struct_type.ptr_type(AddressSpace::default()).into(),
|
||||
inner_ty: Box::new(Some(ExtendedTypeEnum::BasicEnum(struct_type.into())))
|
||||
})
|
||||
// struct_type.ptr_type(AddressSpace::default()).into()
|
||||
};
|
||||
return ty;
|
||||
}
|
||||
@ -511,23 +538,27 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let fields = ty
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty).get_type()
|
||||
})
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).into()
|
||||
ExtendedTypeEnum::BasicEnum(ctx.struct_type(&fields, false).into())
|
||||
// ctx.struct_type(&fields, false).into()
|
||||
}
|
||||
TList { ty } => {
|
||||
let element_type =
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
|
||||
|
||||
ListType::new(generator, ctx, element_type).as_base_type().into()
|
||||
// Assuming same as numpy
|
||||
ExtendedTypeEnum::BasicEnum(ListType::new(generator, ctx, element_type.get_type()).as_base_type().into())
|
||||
// ListType::new(generator, ctx, element_type).as_base_type().into()
|
||||
}
|
||||
TVirtual { .. } => unimplemented!(),
|
||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||
};
|
||||
type_cache.insert(unifier.get_representative(ty), result);
|
||||
type_cache.insert(unifier.get_representative(ty), result.get_type());
|
||||
// type_cache.insert(unifier.get_representative(ty), result);
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves the [LLVM type][`BasicTypeEnum`] corresponding to the [`Type`].
|
||||
|
Loading…
Reference in New Issue
Block a user