diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index b21e721..32b95a7 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -14,7 +14,7 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, TupleType}, values::{ ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, @@ -1868,34 +1868,6 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( }) } -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: &[BasicValueEnum<'ctx>], -) -> PointerValue<'ctx> { - let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, *v).unwrap(); - } - } - out_ptr -} - /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -1973,10 +1945,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( None, ); - let q = q.as_base_value().into(); - let r = r.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[q, r]); - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + let q = q.as_base_value().as_basic_value_enum(); + let r = r.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) + .construct_from_objects(ctx, [q, r], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_svd` linalg function @@ -2031,12 +2004,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( None, ); - let u = u.as_base_value().into(); - let s = s.as_base_value().into(); - let vh = vh.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[u, s, vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + let u = u.as_base_value().as_basic_value_enum(); + let s = s.as_base_value().as_basic_value_enum(); + let vh = vh.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + .construct_from_objects(ctx, [u, s, vh], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_inv` linalg function @@ -2158,10 +2131,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( None, ); - let l = l.as_base_value().into(); - let u = u.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[l, u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + let l = l.as_base_value().as_basic_value_enum(); + let u = u.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) + .construct_from_objects(ctx, [l, u], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -2293,10 +2267,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( None, ); - let t = t.as_base_value().into(); - let z = z.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[t, z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) + let t = t.as_base_value().as_basic_value_enum(); + let z = z.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) + .construct_from_objects(ctx, [t, z], None); + Ok(tuple.as_base_value().into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2337,8 +2312,9 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( None, ); - let h = h.as_base_value().into(); - let q = q.as_base_value().into(); - let out_ptr = build_output_struct(ctx, &[h, q]); - Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) + let h = h.as_base_value().as_basic_value_enum(); + let q = q.as_base_value().as_basic_value_enum(); + let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) + .construct_from_objects(ctx, [h, q], None); + Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1e0fb26..2ce3c9a 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -42,7 +42,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -574,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - ctx.struct_type(&fields, false).into() + TupleType::new(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 98bd43b..0a31d6a 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -28,11 +28,13 @@ use super::{ }; pub use list::*; pub use range::*; +pub use tuple::*; mod list; pub mod ndarray; mod range; pub mod structure; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs new file mode 100644 index 0000000..ccb63b4 --- /dev/null +++ b/nac3core/src/codegen/types/tuple.rs @@ -0,0 +1,184 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, StructType}, + values::BasicValueEnum, +}; +use itertools::Itertools; + +use super::ProxyType; +use crate::{ + codegen::{ + values::{ProxyValue, TupleValue}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct TupleType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> TupleType<'ctx> { + /// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not. + pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> { + Ok(()) + } + + /// Creates an LLVM type corresponding to the expected structure of a tuple. + #[must_use] + fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { + ctx.struct_type(tys, false) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + + /// Creates an [`TupleType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx.ctx); + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { + panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty)); + }; + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize } + } + + /// Creates an [`TupleType`] from a [`StructType`]. + #[must_use] + pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_representable(struct_ty).is_ok()); + + TupleType { ty: struct_ty, llvm_usize } + } + + /// Returns the number of elements present in this [`TupleType`]. + #[must_use] + pub fn num_elements(&self) -> u32 { + self.ty.count_fields() + } + + /// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of + /// range. + #[must_use] + pub fn type_at_index(&self, index: u32) -> Option> { + if index < self.num_elements() { + Some(unsafe { self.type_at_index_unchecked(index) }) + } else { + None + } + } + + /// Returns the type of the tuple element at the given `index`. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid. + #[must_use] + pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> { + self.ty.get_field_type_at_index_unchecked(index) + } + + /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. + #[must_use] + pub fn construct( + &self, + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name) + } + + /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of + /// objects. + #[must_use] + pub fn construct_from_objects>>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + objects: I, + name: Option<&'ctx str>, + ) -> >::Value { + let values = objects.into_iter().collect_vec(); + + assert_eq!(values.len(), self.num_elements() as usize); + assert!(values + .iter() + .enumerate() + .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); + + let mut value = self.construct(ctx, name); + for (i, val) in values.into_iter().enumerate() { + value.store_element(ctx, i as u32, val); + } + + value + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type Base = StructType<'ctx>; + type Value = TupleValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected struct type, got {llvm_ty:?}")) + } + } + + fn is_representable( + _generator: &G, + _ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_base_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: TupleType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 032f041..c789fe0 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -5,11 +5,13 @@ use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use range::*; +pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +mod tuple; pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs new file mode 100644 index 0000000..5167e47 --- /dev/null +++ b/nac3core/src/codegen/values/tuple.rs @@ -0,0 +1,85 @@ +use inkwell::{ + types::IntType, + values::{BasicValue, BasicValueEnum, StructValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::TupleType, CodeGenContext}; + +#[derive(Copy, Clone)] +pub struct TupleValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> TupleValue<'ctx> { + /// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an + /// instance. + pub fn is_representable( + value: StructValue<'ctx>, + _llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + TupleType::is_representable(value.get_type()) + } + + /// Creates an [`TupleValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(value, llvm_usize).is_ok()); + + Self { value, llvm_usize, name } + } + + /// Stores a value into the tuple element at the given `index`. + pub fn store_element( + &mut self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + element: impl BasicValue<'ctx>, + ) { + assert_eq!(element.as_basic_value_enum().get_type(), unsafe { + self.get_type().type_at_index_unchecked(index) + }); + + let new_value = ctx + .builder + .build_insert_value(self.value, element, index, self.name.unwrap_or_default()) + .unwrap(); + self.value = new_value.into_struct_value(); + } + + /// Loads a value from the tuple element at the given `index`. + pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + ctx.builder + .build_extract_value( + self.value, + index, + &format!("{}[{{i}}]", self.name.unwrap_or("tuple")), + ) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type Base = StructValue<'ctx>; + type Type = TupleType<'ctx>; + + fn get_type(&self) -> Self::Type { + TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: TupleValue<'ctx>) -> Self { + value.as_base_value() + } +}