diff --git a/nac3core/src/typecheck/escape_analysis/mod.rs b/nac3core/src/typecheck/escape_analysis/mod.rs index d33ca4b8..7fdbf7a6 100644 --- a/nac3core/src/typecheck/escape_analysis/mod.rs +++ b/nac3core/src/typecheck/escape_analysis/mod.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use nac3parser::ast::{Constant, Expr, ExprKind, StrRef}; +use nac3parser::ast::{Constant, Expr, ExprKind, Stmt, StmtKind, StrRef}; use crate::{ symbol_resolver::SymbolResolver, @@ -16,11 +16,14 @@ use super::{ pub mod lifetime; +#[cfg(test)] +mod test; + struct LifetimeContext<'a> { variable_mapping: HashMap, scope_ctx: BlockLifetimeContext, lifetime_table: LifetimeTable, - primitive_store: PrimitiveStore, + primitive_store: &'a PrimitiveStore, unifier: &'a mut Unifier, resolver: Arc, top_level: &'a TopLevelContext, @@ -29,7 +32,7 @@ struct LifetimeContext<'a> { impl<'a> LifetimeContext<'a> { pub fn new( unifier: &'a mut Unifier, - primitive_store: PrimitiveStore, + primitive_store: &'a PrimitiveStore, resolver: Arc, top_level: &'a TopLevelContext, ) -> LifetimeContext<'a> { @@ -293,7 +296,7 @@ impl<'a> LifetimeContext<'a> { Some((self.lifetime_table.get_unknown_lifetime(), false)) } } - ExprKind::Compare { left, comparators, .. } => { + ExprKind::Compare { left, comparators, .. } => { let mut lifetimes = Vec::new(); if let Some(l) = self.get_expr_lifetime(left)? { lifetimes.push(l.0); @@ -312,6 +315,90 @@ impl<'a> LifetimeContext<'a> { // TODO: listcomp, ifexpr _ => unimplemented!(), }) + } + fn handle_assignment( + &mut self, + lhs: &Expr>, + rhs_lifetime: Option<(Lifetime, bool)>, + ) -> Result<(), String> { + match &lhs.node { + ExprKind::Attribute { value, attr, .. } => { + let (lhs_lifetime, is_strong_update) = self.get_expr_lifetime(value)?.unwrap(); + if let Some((lifetime, _)) = rhs_lifetime { + self.lifetime_table + .set_field_lifetime( + lhs_lifetime, + *attr, + lifetime, + is_strong_update, + &mut self.scope_ctx, + ) + .map_err(|_| format!("illegal field assignment in {}", lhs.location))?; + } + } + ExprKind::Subscript { value, slice, .. } => { + let (list_lifetime, _) = self.get_expr_lifetime(value)?.unwrap(); + let elem_lifetime = if let ExprKind::Slice { lower, upper, step } = &slice.node { + // compute side effects + for expr in [lower, upper, step].iter().filter_map(|x| x.as_ref()) { + // account for side effects when computing the slice + self.get_expr_lifetime(expr)?; + } + // slice assignment will copy elements from rhs to lhs + self.lifetime_table.get_field_lifetime( + rhs_lifetime.unwrap().0, + "elem".into(), + &mut self.scope_ctx, + ) + } else { + // must be list element, as assignment to tuple element is prohibited + self.get_expr_lifetime(slice)?; + rhs_lifetime.unwrap().0 + }; + self.lifetime_table + .set_field_lifetime( + list_lifetime, + "elem".into(), + elem_lifetime, + false, + &mut self.scope_ctx, + ) + .map_err(|_| format!("illegal element assignment in {}", lhs.location))?; + } + ExprKind::Name { id, .. } => { + if let Some(lifetime) = rhs_lifetime { + self.variable_mapping.insert(*id, lifetime); + } + } + ExprKind::Tuple { elts, .. } => { + for (i, e) in elts.iter().enumerate() { + let elem_lifetime = self.lifetime_table.get_field_lifetime( + rhs_lifetime.unwrap().0, + format!("elem{}", i).into(), + &mut self.scope_ctx, + ); + self.handle_assignment(e, Some((elem_lifetime, false)))?; + } + } + _ => unreachable!(), + } + Ok(()) + } + + pub fn handle_statement(&mut self, stmt: &Stmt>) -> Result<(), String> { + match &stmt.node { + StmtKind::Expr { value, .. } => { + self.get_expr_lifetime(value)?; + } + StmtKind::Assign { targets, value, .. } => { + let rhs_lifetime = self.get_expr_lifetime(value)?; + for target in targets.iter() { + self.handle_assignment(target, rhs_lifetime)?; + } + } + _ => unimplemented!(), + } + Ok(()) } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 510a0f09..d865926a 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -14,7 +14,7 @@ use nac3parser::ast::{ }; #[cfg(test)] -mod test; +pub(crate) mod test; #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)] pub struct CodeLocation { diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 74930814..252dd1f5 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -11,7 +11,7 @@ use nac3parser::parser::parse_program; use parking_lot::RwLock; use test_case::test_case; -struct Resolver { +pub(crate) struct Resolver { id_to_type: HashMap, id_to_def: HashMap, class_names: HashMap, @@ -56,7 +56,7 @@ impl SymbolResolver for Resolver { } } -struct TestEnvironment { +pub(crate) struct TestEnvironment { pub unifier: Unifier, pub function_data: FunctionData, pub primitives: PrimitiveStore, @@ -192,7 +192,7 @@ impl TestEnvironment { } } - fn new() -> TestEnvironment { + pub fn new() -> TestEnvironment { let mut unifier = Unifier::new(); let mut identifier_mapping = HashMap::new(); let mut top_level_defs: Vec>> = Vec::new(); @@ -447,7 +447,7 @@ impl TestEnvironment { } } - fn get_inferencer(&mut self) -> Inferencer { + pub fn get_inferencer(&mut self) -> Inferencer { Inferencer { top_level: &self.top_level, function_data: &mut self.function_data,