1
0
forked from M-Labs/nac3

[core] codegen: Implement Tuple{Type,Value}

This commit is contained in:
David Mak 2024-12-16 13:56:08 +08:00
parent 822f9d33f8
commit 7d02f5833d
6 changed files with 303 additions and 54 deletions

View File

@ -1,6 +1,6 @@
use inkwell::{ use inkwell::{
types::BasicTypeEnum, types::BasicTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, IntValue},
FloatPredicate, IntPredicate, OptimizationLevel, FloatPredicate, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -14,7 +14,7 @@ use super::{
numpy, numpy,
numpy::ndarray_elementwise_unaryop_impl, numpy::ndarray_elementwise_unaryop_impl,
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
types::ndarray::NDArrayType, types::{ndarray::NDArrayType, TupleType},
values::{ values::{
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
@ -1868,34 +1868,6 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
fn build_output_struct<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: &[BasicValueEnum<'ctx>],
) -> PointerValue<'ctx> {
let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
out_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"",
)
.unwrap();
ctx.builder.build_store(ptr, *v).unwrap();
}
}
out_ptr
}
/// Invokes the `np_linalg_cholesky` linalg function /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
@ -1973,10 +1945,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let q = q.as_base_value().into(); let q = q.as_base_value().as_basic_value_enum();
let r = r.as_base_value().into(); let r = r.as_base_value().as_basic_value_enum();
let out_ptr = build_output_struct(ctx, &[q, r]); let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()])
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) .construct_from_objects(ctx, [q, r], None);
Ok(tuple.as_base_value().into())
} }
/// Invokes the `np_linalg_svd` linalg function /// Invokes the `np_linalg_svd` linalg function
@ -2031,12 +2004,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let u = u.as_base_value().into(); let u = u.as_base_value().as_basic_value_enum();
let s = s.as_base_value().into(); let s = s.as_base_value().as_basic_value_enum();
let vh = vh.as_base_value().into(); let vh = vh.as_base_value().as_basic_value_enum();
let out_ptr = build_output_struct(ctx, &[u, s, vh]); let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()])
.construct_from_objects(ctx, [u, s, vh], None);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) Ok(tuple.as_base_value().into())
} }
/// Invokes the `np_linalg_inv` linalg function /// Invokes the `np_linalg_inv` linalg function
@ -2158,10 +2131,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let l = l.as_base_value().into(); let l = l.as_base_value().as_basic_value_enum();
let u = u.as_base_value().into(); let u = u.as_base_value().as_basic_value_enum();
let out_ptr = build_output_struct(ctx, &[l, u]); let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()])
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) .construct_from_objects(ctx, [l, u], None);
Ok(tuple.as_base_value().into())
} }
/// Invokes the `np_linalg_matrix_power` linalg function /// Invokes the `np_linalg_matrix_power` linalg function
@ -2293,10 +2267,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let t = t.as_base_value().into(); let t = t.as_base_value().as_basic_value_enum();
let z = z.as_base_value().into(); let z = z.as_base_value().as_basic_value_enum();
let out_ptr = build_output_struct(ctx, &[t, z]); let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()])
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) .construct_from_objects(ctx, [t, z], None);
Ok(tuple.as_base_value().into())
} }
/// Invokes the `sp_linalg_hessenberg` linalg function /// Invokes the `sp_linalg_hessenberg` linalg function
@ -2337,8 +2312,9 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
None, None,
); );
let h = h.as_base_value().into(); let h = h.as_base_value().as_basic_value_enum();
let q = q.as_base_value().into(); let q = q.as_base_value().as_basic_value_enum();
let out_ptr = build_output_struct(ctx, &[h, q]); let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()])
Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) .construct_from_objects(ctx, [h, q], None);
Ok(tuple.as_base_value().into())
} }

View File

@ -42,7 +42,7 @@ use crate::{
}; };
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator}; pub use generator::{CodeGenerator, DefaultCodeGenerator};
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType}; use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType};
pub mod builtin_fns; pub mod builtin_fns;
pub mod concrete_type; pub mod concrete_type;
@ -574,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() TupleType::new(generator, ctx, &fields).as_base_type().into()
} }
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()), _ => unreachable!("{}", ty_enum.get_type_name()),

View File

@ -28,11 +28,13 @@ use super::{
}; };
pub use list::*; pub use list::*;
pub use range::*; pub use range::*;
pub use tuple::*;
mod list; mod list;
pub mod ndarray; pub mod ndarray;
mod range; mod range;
pub mod structure; pub mod structure;
mod tuple;
pub mod utils; pub mod utils;
/// A LLVM type that is used to represent a corresponding type in NAC3. /// A LLVM type that is used to represent a corresponding type in NAC3.

View File

@ -0,0 +1,184 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, IntType, StructType},
values::BasicValueEnum,
};
use itertools::Itertools;
use super::ProxyType;
use crate::{
codegen::{
values::{ProxyValue, TupleValue},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum},
};
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TupleType<'ctx> {
ty: StructType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> TupleType<'ctx> {
/// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not.
pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> {
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a tuple.
#[must_use]
fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> {
ctx.struct_type(tys, false)
}
/// Creates an instance of [`TupleType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
tys: &[BasicTypeEnum<'ctx>],
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_tuple = Self::llvm_type(ctx, tys);
Self { ty: llvm_tuple, llvm_usize }
}
/// Creates an [`TupleType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let llvm_usize = generator.get_size_type(ctx.ctx);
// Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {
panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty));
};
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize }
}
/// Creates an [`TupleType`] from a [`StructType`].
#[must_use]
pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(struct_ty).is_ok());
TupleType { ty: struct_ty, llvm_usize }
}
/// Returns the number of elements present in this [`TupleType`].
#[must_use]
pub fn num_elements(&self) -> u32 {
self.ty.count_fields()
}
/// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of
/// range.
#[must_use]
pub fn type_at_index(&self, index: u32) -> Option<BasicTypeEnum<'ctx>> {
if index < self.num_elements() {
Some(unsafe { self.type_at_index_unchecked(index) })
} else {
None
}
}
/// Returns the type of the tuple element at the given `index`.
///
/// # Safety
///
/// The caller must ensure that the index is valid.
#[must_use]
pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> {
self.ty.get_field_type_at_index_unchecked(index)
}
/// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value.
#[must_use]
pub fn construct(
&self,
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name)
}
/// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of
/// objects.
#[must_use]
pub fn construct_from_objects<I: IntoIterator<Item = BasicValueEnum<'ctx>>>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
objects: I,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let values = objects.into_iter().collect_vec();
assert_eq!(values.len(), self.num_elements() as usize);
assert!(values
.iter()
.enumerate()
.all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } }));
let mut value = self.construct(ctx, name);
for (i, val) in values.into_iter().enumerate() {
value.store_element(ctx, i as u32, val);
}
value
}
/// Converts an existing value into a [`ListValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_struct_value(value, self.llvm_usize, name)
}
}
impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> {
type Base = StructType<'ctx>;
type Value = TupleValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected struct type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
_generator: &G,
_ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty)
}
fn alloca_type(&self) -> impl BasicType<'ctx> {
self.as_base_type()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<TupleType<'ctx>> for StructType<'ctx> {
fn from(value: TupleType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -5,11 +5,13 @@ use crate::codegen::CodeGenerator;
pub use array::*; pub use array::*;
pub use list::*; pub use list::*;
pub use range::*; pub use range::*;
pub use tuple::*;
mod array; mod array;
mod list; mod list;
pub mod ndarray; pub mod ndarray;
mod range; mod range;
mod tuple;
pub mod utils; pub mod utils;
/// A LLVM type that is used to represent a non-primitive value in NAC3. /// A LLVM type that is used to represent a non-primitive value in NAC3.

View File

@ -0,0 +1,85 @@
use inkwell::{
types::IntType,
values::{BasicValue, BasicValueEnum, StructValue},
};
use super::ProxyValue;
use crate::codegen::{types::TupleType, CodeGenContext};
#[derive(Copy, Clone)]
pub struct TupleValue<'ctx> {
value: StructValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> TupleValue<'ctx> {
/// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: StructValue<'ctx>,
_llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
TupleType::is_representable(value.get_type())
}
/// Creates an [`TupleValue`] from a [`StructValue`].
#[must_use]
pub fn from_struct_value(
value: StructValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(value, llvm_usize).is_ok());
Self { value, llvm_usize, name }
}
/// Stores a value into the tuple element at the given `index`.
pub fn store_element(
&mut self,
ctx: &CodeGenContext<'ctx, '_>,
index: u32,
element: impl BasicValue<'ctx>,
) {
assert_eq!(element.as_basic_value_enum().get_type(), unsafe {
self.get_type().type_at_index_unchecked(index)
});
let new_value = ctx
.builder
.build_insert_value(self.value, element, index, self.name.unwrap_or_default())
.unwrap();
self.value = new_value.into_struct_value();
}
/// Loads a value from the tuple element at the given `index`.
pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> {
ctx.builder
.build_extract_value(
self.value,
index,
&format!("{}[{{i}}]", self.name.unwrap_or("tuple")),
)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> {
type Base = StructValue<'ctx>;
type Type = TupleType<'ctx>;
fn get_type(&self) -> Self::Type {
TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<TupleValue<'ctx>> for StructValue<'ctx> {
fn from(value: TupleValue<'ctx>) -> Self {
value.as_base_value()
}
}