diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 7db69c2..21576e0 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -3,7 +3,7 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; -use std::{collections::HashSet, iter::once, ops::Not}; +use std::{collections::HashSet, iter::once}; impl<'a> Inferencer<'a> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { @@ -208,6 +208,27 @@ impl<'a> Inferencer<'a> { Ok(()) } + /// Check that the return value is a non-`alloca` type, effectively only allowing primitive types. + /// + /// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which + /// is freed when the function returns. + fn check_return_value_ty(&mut self, ret_ty: Type) -> bool { + match &*self.unifier.get_ty_immutable(ret_ty) { + TypeEnum::TObj { .. } => { + [ + self.primitives.int32, + self.primitives.int64, + self.primitives.uint32, + self.primitives.uint64, + self.primitives.float, + self.primitives.bool, + ].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)) + } + TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)), + _ => false, + } + } + // check statements for proper identifier def-use and return on all paths fn check_stmt( &mut self, @@ -307,24 +328,19 @@ impl<'a> Inferencer<'a> { // This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which // is freed when the function returns. if let Some(ret_ty) = value.custom { - if [ - self.primitives.int32, - self.primitives.int64, - self.primitives.uint32, - self.primitives.uint64, - self.primitives.float, - self.primitives.bool, - ].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)).not() { - // Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually - // inferred and just generates an unconditional assertion - if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }).not() { - return Err(HashSet::from([ - format!( - "return value of type {} must be a primitive", - self.unifier.stringify(ret_ty), - ), - ])) - } + // Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually + // inferred and just generates an unconditional assertion + if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) { + return Ok(true) + } + + if !self.check_return_value_ty(ret_ty) { + return Err(HashSet::from([ + format!( + "return value of type {} must be a primitive of a tuple of primitives", + self.unifier.stringify(ret_ty), + ), + ])) } } }