From b3736c3e99ced4d572e595185efd4efa8aad18d3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 29 Jan 2024 12:49:24 +0800 Subject: [PATCH] core: Disallow returning of non-primitive values Non-primitive values are represented by an `alloca`-ed value in the function body, and when the pointer is returned from the function, the `alloca`-ed object is deallocated on the stack. Related to #54. --- nac3core/src/typecheck/function_check.rs | 27 +++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index a8461f5..7db69c2 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -3,7 +3,7 @@ use crate::typecheck::typedef::TypeEnum; use super::type_inferencer::Inferencer; use super::typedef::Type; use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; -use std::{collections::HashSet, iter::once}; +use std::{collections::HashSet, iter::once, ops::Not}; impl<'a> Inferencer<'a> { fn should_have_value(&mut self, expr: &Expr>) -> Result<(), HashSet> { @@ -302,6 +302,31 @@ impl<'a> Inferencer<'a> { if let Some(value) = value { self.check_expr(value, defined_identifiers)?; self.should_have_value(value)?; + + // Check that the return value is a non-`alloca` type, effectively only allowing primitive types. + // This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which + // is freed when the function returns. + if let Some(ret_ty) = value.custom { + if [ + self.primitives.int32, + self.primitives.int64, + self.primitives.uint32, + self.primitives.uint64, + self.primitives.float, + self.primitives.bool, + ].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)).not() { + // Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually + // inferred and just generates an unconditional assertion + if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }).not() { + return Err(HashSet::from([ + format!( + "return value of type {} must be a primitive", + self.unifier.stringify(ret_ty), + ), + ])) + } + } + } } Ok(true) }