forked from M-Labs/nac3
[core] codegen: Implement Tuple{Type,Value}
This commit is contained in:
parent
822f9d33f8
commit
7d02f5833d
@ -1,6 +1,6 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
values::{BasicValue, BasicValueEnum, IntValue},
|
||||
FloatPredicate, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
@ -14,7 +14,7 @@ use super::{
|
||||
numpy,
|
||||
numpy::ndarray_elementwise_unaryop_impl,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::ndarray::NDArrayType,
|
||||
types::{ndarray::NDArrayType, TupleType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor,
|
||||
@ -1868,34 +1868,6 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
})
|
||||
}
|
||||
|
||||
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
|
||||
fn build_output_struct<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
out_matrices: &[BasicValueEnum<'ctx>],
|
||||
) -> PointerValue<'ctx> {
|
||||
let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||
|
||||
for (i, v) in out_matrices.iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
out_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, *v).unwrap();
|
||||
}
|
||||
}
|
||||
out_ptr
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_cholesky` linalg function
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
@ -1973,10 +1945,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
None,
|
||||
);
|
||||
|
||||
let q = q.as_base_value().into();
|
||||
let r = r.as_base_value().into();
|
||||
let out_ptr = build_output_struct(ctx, &[q, r]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||
let q = q.as_base_value().as_basic_value_enum();
|
||||
let r = r.as_base_value().as_basic_value_enum();
|
||||
let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()])
|
||||
.construct_from_objects(ctx, [q, r], None);
|
||||
Ok(tuple.as_base_value().into())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_svd` linalg function
|
||||
@ -2031,12 +2004,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
None,
|
||||
);
|
||||
|
||||
let u = u.as_base_value().into();
|
||||
let s = s.as_base_value().into();
|
||||
let vh = vh.as_base_value().into();
|
||||
let out_ptr = build_output_struct(ctx, &[u, s, vh]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||
let u = u.as_base_value().as_basic_value_enum();
|
||||
let s = s.as_base_value().as_basic_value_enum();
|
||||
let vh = vh.as_base_value().as_basic_value_enum();
|
||||
let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()])
|
||||
.construct_from_objects(ctx, [u, s, vh], None);
|
||||
Ok(tuple.as_base_value().into())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_inv` linalg function
|
||||
@ -2158,10 +2131,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
None,
|
||||
);
|
||||
|
||||
let l = l.as_base_value().into();
|
||||
let u = u.as_base_value().into();
|
||||
let out_ptr = build_output_struct(ctx, &[l, u]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||
let l = l.as_base_value().as_basic_value_enum();
|
||||
let u = u.as_base_value().as_basic_value_enum();
|
||||
let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()])
|
||||
.construct_from_objects(ctx, [l, u], None);
|
||||
Ok(tuple.as_base_value().into())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||
@ -2293,10 +2267,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
None,
|
||||
);
|
||||
|
||||
let t = t.as_base_value().into();
|
||||
let z = z.as_base_value().into();
|
||||
let out_ptr = build_output_struct(ctx, &[t, z]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||
let t = t.as_base_value().as_basic_value_enum();
|
||||
let z = z.as_base_value().as_basic_value_enum();
|
||||
let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()])
|
||||
.construct_from_objects(ctx, [t, z], None);
|
||||
Ok(tuple.as_base_value().into())
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_hessenberg` linalg function
|
||||
@ -2337,8 +2312,9 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
None,
|
||||
);
|
||||
|
||||
let h = h.as_base_value().into();
|
||||
let q = q.as_base_value().into();
|
||||
let out_ptr = build_output_struct(ctx, &[h, q]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap())
|
||||
let h = h.as_base_value().as_basic_value_enum();
|
||||
let q = q.as_base_value().as_basic_value_enum();
|
||||
let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()])
|
||||
.construct_from_objects(ctx, [h, q], None);
|
||||
Ok(tuple.as_base_value().into())
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ use crate::{
|
||||
};
|
||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType};
|
||||
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType};
|
||||
|
||||
pub mod builtin_fns;
|
||||
pub mod concrete_type;
|
||||
@ -574,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
||||
})
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).into()
|
||||
TupleType::new(generator, ctx, &fields).as_base_type().into()
|
||||
}
|
||||
TVirtual { .. } => unimplemented!(),
|
||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||
|
@ -28,11 +28,13 @@ use super::{
|
||||
};
|
||||
pub use list::*;
|
||||
pub use range::*;
|
||||
pub use tuple::*;
|
||||
|
||||
mod list;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
pub mod structure;
|
||||
mod tuple;
|
||||
pub mod utils;
|
||||
|
||||
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
||||
|
184
nac3core/src/codegen/types/tuple.rs
Normal file
184
nac3core/src/codegen/types/tuple.rs
Normal file
@ -0,0 +1,184 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicType, BasicTypeEnum, IntType, StructType},
|
||||
values::BasicValueEnum,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
values::{ProxyValue, TupleValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct TupleType<'ctx> {
|
||||
ty: StructType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> TupleType<'ctx> {
|
||||
/// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not.
|
||||
pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of a tuple.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> {
|
||||
ctx.struct_type(tys, false)
|
||||
}
|
||||
|
||||
/// Creates an instance of [`TupleType`].
|
||||
#[must_use]
|
||||
pub fn new<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
tys: &[BasicTypeEnum<'ctx>],
|
||||
) -> Self {
|
||||
let llvm_usize = generator.get_size_type(ctx);
|
||||
let llvm_tuple = Self::llvm_type(ctx, tys);
|
||||
|
||||
Self { ty: llvm_tuple, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`TupleType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Sanity check on object type.
|
||||
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {
|
||||
panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty));
|
||||
};
|
||||
|
||||
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
||||
Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`TupleType`] from a [`StructType`].
|
||||
#[must_use]
|
||||
pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::is_representable(struct_ty).is_ok());
|
||||
|
||||
TupleType { ty: struct_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the number of elements present in this [`TupleType`].
|
||||
#[must_use]
|
||||
pub fn num_elements(&self) -> u32 {
|
||||
self.ty.count_fields()
|
||||
}
|
||||
|
||||
/// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of
|
||||
/// range.
|
||||
#[must_use]
|
||||
pub fn type_at_index(&self, index: u32) -> Option<BasicTypeEnum<'ctx>> {
|
||||
if index < self.num_elements() {
|
||||
Some(unsafe { self.type_at_index_unchecked(index) })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the type of the tuple element at the given `index`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must ensure that the index is valid.
|
||||
#[must_use]
|
||||
pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> {
|
||||
self.ty.get_field_type_at_index_unchecked(index)
|
||||
}
|
||||
|
||||
/// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value.
|
||||
#[must_use]
|
||||
pub fn construct(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name)
|
||||
}
|
||||
|
||||
/// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of
|
||||
/// objects.
|
||||
#[must_use]
|
||||
pub fn construct_from_objects<I: IntoIterator<Item = BasicValueEnum<'ctx>>>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
objects: I,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let values = objects.into_iter().collect_vec();
|
||||
|
||||
assert_eq!(values.len(), self.num_elements() as usize);
|
||||
assert!(values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } }));
|
||||
|
||||
let mut value = self.construct(ctx, name);
|
||||
for (i, val) in values.into_iter().enumerate() {
|
||||
value.store_element(ctx, i as u32, val);
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ListValue`].
|
||||
#[must_use]
|
||||
pub fn map_value(
|
||||
&self,
|
||||
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> {
|
||||
type Base = StructType<'ctx>;
|
||||
type Value = TupleValue<'ctx>;
|
||||
|
||||
fn is_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||
} else {
|
||||
Err(format!("Expected struct type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||
_generator: &G,
|
||||
_ctx: &'ctx Context,
|
||||
llvm_ty: Self::Base,
|
||||
) -> Result<(), String> {
|
||||
Self::is_representable(llvm_ty)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_base_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<TupleType<'ctx>> for StructType<'ctx> {
|
||||
fn from(value: TupleType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
@ -5,11 +5,13 @@ use crate::codegen::CodeGenerator;
|
||||
pub use array::*;
|
||||
pub use list::*;
|
||||
pub use range::*;
|
||||
pub use tuple::*;
|
||||
|
||||
mod array;
|
||||
mod list;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
mod tuple;
|
||||
pub mod utils;
|
||||
|
||||
/// A LLVM type that is used to represent a non-primitive value in NAC3.
|
||||
|
85
nac3core/src/codegen/values/tuple.rs
Normal file
85
nac3core/src/codegen/values/tuple.rs
Normal file
@ -0,0 +1,85 @@
|
||||
use inkwell::{
|
||||
types::IntType,
|
||||
values::{BasicValue, BasicValueEnum, StructValue},
|
||||
};
|
||||
|
||||
use super::ProxyValue;
|
||||
use crate::codegen::{types::TupleType, CodeGenContext};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct TupleValue<'ctx> {
|
||||
value: StructValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> TupleValue<'ctx> {
|
||||
/// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an
|
||||
/// instance.
|
||||
pub fn is_representable(
|
||||
value: StructValue<'ctx>,
|
||||
_llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
TupleType::is_representable(value.get_type())
|
||||
}
|
||||
|
||||
/// Creates an [`TupleValue`] from a [`StructValue`].
|
||||
#[must_use]
|
||||
pub fn from_struct_value(
|
||||
value: StructValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_representable(value, llvm_usize).is_ok());
|
||||
|
||||
Self { value, llvm_usize, name }
|
||||
}
|
||||
|
||||
/// Stores a value into the tuple element at the given `index`.
|
||||
pub fn store_element(
|
||||
&mut self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
index: u32,
|
||||
element: impl BasicValue<'ctx>,
|
||||
) {
|
||||
assert_eq!(element.as_basic_value_enum().get_type(), unsafe {
|
||||
self.get_type().type_at_index_unchecked(index)
|
||||
});
|
||||
|
||||
let new_value = ctx
|
||||
.builder
|
||||
.build_insert_value(self.value, element, index, self.name.unwrap_or_default())
|
||||
.unwrap();
|
||||
self.value = new_value.into_struct_value();
|
||||
}
|
||||
|
||||
/// Loads a value from the tuple element at the given `index`.
|
||||
pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> {
|
||||
ctx.builder
|
||||
.build_extract_value(
|
||||
self.value,
|
||||
index,
|
||||
&format!("{}[{{i}}]", self.name.unwrap_or("tuple")),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> {
|
||||
type Base = StructValue<'ctx>;
|
||||
type Type = TupleType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<TupleValue<'ctx>> for StructValue<'ctx> {
|
||||
fn from(value: TupleValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user