diff --git a/nac3core/src/typecheck/escape_analysis/lifetime.rs b/nac3core/src/typecheck/escape_analysis/lifetime.rs new file mode 100644 index 000000000..97664387e --- /dev/null +++ b/nac3core/src/typecheck/escape_analysis/lifetime.rs @@ -0,0 +1,276 @@ +use std::cell::RefCell; +use std::collections::{HashMap, HashSet}; +use std::rc::Rc; + +use crate::typecheck::unification_table::{UnificationKey, UnificationTable}; + +use itertools::Itertools; +use nac3parser::ast::StrRef; + +// change this to enum, only local needs unification key +pub type Lifetime = UnificationKey; + +#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash)] +pub enum LifetimeKind { + // can be assigned to fields of anything + // can be returned + // lifetime of static values + Global, + // can only be assigned to fields of objects with local lifetime + // can be returned + // lifetime of parameters + NonLocal, + // can only be assigned to fields of objects with local lifetime + // cannot be returned + // lifetime of constructor return values + Local, + // can only be assigned to fields of objects with local lifetime + // cannot be returned + // lifetime of function return values + Unknown, +} + +impl std::ops::BitAnd for LifetimeKind { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + use LifetimeKind::*; + match (self, other) { + (x, y) if x == y => x, + (Global, NonLocal) | (NonLocal, Global) => NonLocal, + _ => Unknown, + } + } +} + +impl std::cmp::PartialOrd for LifetimeKind { + fn partial_cmp(&self, other: &Self) -> Option { + use LifetimeKind::*; + match (*self, *other) { + (x, y) if x == y => Some(std::cmp::Ordering::Equal), + (Local, _) | (_, Global) => Some(std::cmp::Ordering::Less), + (_, Local) | (Global, _) => Some(std::cmp::Ordering::Greater), + _ => None, + } + } +} + +pub struct BlockLifetimeContext { + mapping: Vec<(Option, Lifetime)>, +} + +impl BlockLifetimeContext { + pub fn new() -> Self { + BlockLifetimeContext { mapping: Vec::new() } + } + + pub fn add_fresh(&mut self, lifetime: Lifetime) { + self.mapping.push((None, lifetime)); + } +} + +struct LifetimeEntry { + kind: LifetimeKind, + fields: RefCell>, +} + +pub struct LifetimeTable { + table: UnificationTable>, + cache: HashSet<(Lifetime, Lifetime)>, +} + +impl LifetimeTable { + pub fn new() -> Self { + let mut zelf = Self { table: UnificationTable::new(), cache: Default::default() }; + zelf.table.new_key(Rc::new(LifetimeEntry { + kind: LifetimeKind::Unknown, + fields: Default::default(), + })); + zelf + } + + pub fn add_lifetime(&mut self, kind: LifetimeKind) -> Lifetime { + self.table.new_key(Rc::new(LifetimeEntry { kind, fields: Default::default() })) + } + + pub fn unify(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) { + self.cache.clear(); + self.unify_impl(a, b, ctx); + } + + fn get_scoped( + &mut self, + mut lifetimes: [Lifetime; N], + ctx: &mut BlockLifetimeContext, + ) -> [Lifetime; N] { + for l in lifetimes.iter_mut() { + let mut result = None; + for (k, v) in ctx.mapping.iter() { + if self.table.unioned(*v, *l) || k.map_or(false, |k| self.table.unioned(k, *l)) { + // already fresh + result = Some(*v); + break; + } + } + if let Some(result) = result { + *l = result; + } else { + let lifetime = self.table.probe_value(*l).clone(); + *l = if lifetime.kind == LifetimeKind::Unknown { + UnificationKey(0) + } else { + let k = self.table.new_key(lifetime); + ctx.mapping.push((Some(*l), k)); + k + } + } + } + lifetimes + } + + fn unify_impl(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) { + use LifetimeKind::*; + + let [a, b] = self.get_scoped([a, b], ctx); + let a = self.table.get_representative(a); + let b = self.table.get_representative(b); + if a == b || self.cache.contains(&(a, b)) || self.cache.contains(&(b, a)) { + return; + } + self.cache.insert((a, b)); + + let v_a = self.table.probe_value(a).clone(); + let v_b = self.table.probe_value(b).clone(); + + let result_kind = v_a.kind & v_b.kind; + + let fields = if result_kind == Local { + // we only need to track fields lifetime for objects with local lifetime + let fields = v_a.fields.clone(); + { + let mut fields_ref = fields.borrow_mut(); + for (k, v) in v_b.fields.borrow().iter() { + if let Some(old) = fields_ref.insert(k.clone(), *v) { + self.unify_impl(old, *v, ctx); + } + } + } + fields + } else { + Default::default() + }; + + self.table.unify(a, b); + self.table.set_value(a, Rc::new(LifetimeEntry { kind: result_kind, fields })); + } + + pub fn get_field_lifetime( + &mut self, + lifetime: Lifetime, + field: StrRef, + ctx: &mut BlockLifetimeContext, + ) -> Lifetime { + use LifetimeKind::*; + let [lifetime] = self.get_scoped([lifetime], ctx); + if let LifetimeEntry { kind: Local, fields } = &*self.table.probe_value(lifetime).clone() { + if let Some(lifetime) = fields.borrow().get(&field) { + *lifetime + } else { + // unknown lifetime + // we can reuse this lifetime because it will never be unified to something else + UnificationKey(0) + } + } else { + lifetime + } + } + + pub fn set_field_lifetime( + &mut self, + obj: Lifetime, + field: StrRef, + lifetime: Lifetime, + is_strong_update: bool, + ctx: &mut BlockLifetimeContext, + ) -> Result<(), String> { + let [obj, lifetime] = self.get_scoped([obj, lifetime], ctx); + let obj_lifetime = self.table.probe_value(obj).clone(); + let field_lifetime = self.table.probe_value(lifetime).clone(); + if !(obj_lifetime.kind <= field_lifetime.kind) { + return Err("lifetime error".to_string()); + } + let mut fields = obj_lifetime.fields.borrow_mut(); + if is_strong_update { + fields.insert(field, lifetime); + } else { + if let Some(old) = fields.insert(field, lifetime) { + self.unify(old, lifetime, ctx); + } + } + Ok(()) + } + + pub fn get_lifetime_kind( + &mut self, + lifetime: Lifetime, + ctx: &mut BlockLifetimeContext, + ) -> LifetimeKind { + let [lifetime] = self.get_scoped([lifetime], ctx); + self.table.probe_value(lifetime).kind + } + + pub fn set_function_params(&mut self, lifetime: Lifetime, ctx: &mut BlockLifetimeContext) { + use LifetimeKind::*; + // unify each field with global + let [lifetime] = self.get_scoped([lifetime], ctx); + let lifetime = self.table.probe_value(lifetime).clone(); + let mut worklist = lifetime.fields.borrow().values().copied().collect_vec(); + while let Some(item) = worklist.pop() { + let [item] = self.get_scoped([item], ctx); + let lifetime = self.table.probe_value(item).clone(); + if lifetime.kind == Unknown || lifetime.kind == Global { + continue; + } + let fields = lifetime.fields.borrow().clone(); + for (_, v) in fields.iter() { + worklist.push(*v); + } + self.table.set_value( + item, + Rc::new(LifetimeEntry { + kind: lifetime.kind & Global, + fields: RefCell::new(fields), + }), + ); + } + } + + pub fn get_unknown_lifetime(&self) -> Lifetime { + UnificationKey(0) + } + + pub fn equiv(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) -> bool { + use LifetimeKind::Local; + let [a, b] = self.get_scoped([a, b], ctx); + if self.table.unioned(a, b) { + return true; + } + let lifetime_a = self.table.probe_value(a).clone(); + let lifetime_b = self.table.probe_value(b).clone(); + if lifetime_a.kind == Local && lifetime_b.kind == Local { + let fields_a = lifetime_a.fields.borrow(); + let fields_b = lifetime_b.fields.borrow(); + for (k, v) in fields_a.iter() { + if fields_b.get(k).map(|v1| self.equiv(*v, *v1, ctx)) != Some(true) { + return false; + } + } + // they are just equivalent + // this can avoid infinite recursion + self.table.unify(a, b); + true + } else { + lifetime_a.kind == lifetime_b.kind + } + } +} diff --git a/nac3core/src/typecheck/escape_analysis/mod.rs b/nac3core/src/typecheck/escape_analysis/mod.rs new file mode 100644 index 000000000..d33ca4b89 --- /dev/null +++ b/nac3core/src/typecheck/escape_analysis/mod.rs @@ -0,0 +1,317 @@ +use std::{collections::HashMap, sync::Arc}; + +use nac3parser::ast::{Constant, Expr, ExprKind, StrRef}; + +use crate::{ + symbol_resolver::SymbolResolver, + toplevel::{TopLevelContext, TopLevelDef}, +}; + +use self::lifetime::{BlockLifetimeContext, Lifetime, LifetimeTable}; + +use super::{ + type_inferencer::PrimitiveStore, + typedef::{Type, TypeEnum, Unifier}, +}; + +pub mod lifetime; + +struct LifetimeContext<'a> { + variable_mapping: HashMap, + scope_ctx: BlockLifetimeContext, + lifetime_table: LifetimeTable, + primitive_store: PrimitiveStore, + unifier: &'a mut Unifier, + resolver: Arc, + top_level: &'a TopLevelContext, +} + +impl<'a> LifetimeContext<'a> { + pub fn new( + unifier: &'a mut Unifier, + primitive_store: PrimitiveStore, + resolver: Arc, + top_level: &'a TopLevelContext, + ) -> LifetimeContext<'a> { + LifetimeContext { + variable_mapping: HashMap::new(), + scope_ctx: BlockLifetimeContext::new(), + lifetime_table: LifetimeTable::new(), + primitive_store, + unifier, + resolver, + top_level, + } + } + + fn get_expr_lifetime( + &mut self, + expr: &Expr>, + ) -> Result, String> { + let ty = expr.custom.unwrap(); + let is_primitive = self.unifier.unioned(ty, self.primitive_store.int32) + || self.unifier.unioned(ty, self.primitive_store.int64) + || self.unifier.unioned(ty, self.primitive_store.uint32) + || self.unifier.unioned(ty, self.primitive_store.uint64) + || self.unifier.unioned(ty, self.primitive_store.float) + || self.unifier.unioned(ty, self.primitive_store.bool) + || self.unifier.unioned(ty, self.primitive_store.none) + || self.unifier.unioned(ty, self.primitive_store.range); + + Ok(match &expr.node { + ExprKind::Name { id, .. } => { + if let Some(lifetime) = self.variable_mapping.get(id) { + Some(*lifetime) + } else { + if is_primitive { + None + } else { + let lifetime = + self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Global); + self.variable_mapping.insert(id.clone(), (lifetime, false)); + Some((lifetime, false)) + } + } + } + ExprKind::Attribute { value, attr, .. } => { + if is_primitive { + self.get_expr_lifetime(value)?; + None + } else { + self.get_expr_lifetime(value)?.map(|lifetime| { + ( + self.lifetime_table.get_field_lifetime( + lifetime.0, + *attr, + &mut self.scope_ctx, + ), + false, // not sure if it is strong update for now... + ) + }) + } + } + ExprKind::Constant { .. } => { + if is_primitive { + None + } else { + Some((self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Global), false)) + } + } + ExprKind::List { elts, .. } => { + let elems = + elts.iter() + .map(|expr| self.get_expr_lifetime(expr)) + .collect::, _>>()?; + let elem = elems.into_iter().reduce(|prev, next| { + if prev.is_some() { + self.lifetime_table.unify( + prev.unwrap().0, + next.unwrap().0, + &mut self.scope_ctx, + ); + } + prev + }); + let list_lifetime = self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local); + + if let Some(Some(elem)) = elem { + self.lifetime_table + .set_field_lifetime( + list_lifetime, + "elem".into(), + elem.0, + true, + &mut self.scope_ctx, + ) + .unwrap(); + } + Some((list_lifetime, true)) + } + ExprKind::Subscript { value, slice, .. } => { + // value must be a list, so lifetime cannot be None + let (value_lifetime, _) = self.get_expr_lifetime(value)?.unwrap(); + match &slice.node { + ExprKind::Slice { lower, upper, step } => { + for expr in [lower, upper, step].iter().filter_map(|x| x.as_ref()) { + // account for side effects when computing the slice + self.get_expr_lifetime(expr)?; + } + Some(( + self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local), + true, + )) + } + ExprKind::Constant { value: Constant::Int(v), .. } => { + if is_primitive { + None + } else if let TypeEnum::TList { .. } = + &*self.unifier.get_ty(value.custom.unwrap()) + { + Some(( + self.lifetime_table.get_field_lifetime( + value_lifetime, + "elem".into(), + &mut self.scope_ctx, + ), + false, + )) + } else { + // tuple + Some(( + self.lifetime_table.get_field_lifetime( + value_lifetime, + format!("elem{}", v).into(), + &mut self.scope_ctx, + ), + false, + )) + } + } + _ => { + // account for side effects when computing the index + self.get_expr_lifetime(slice)?; + if is_primitive { + None + } else { + Some(( + self.lifetime_table.get_field_lifetime( + value_lifetime, + "elem".into(), + &mut self.scope_ctx, + ), + false, + )) + } + } + } + } + ExprKind::Tuple { elts, .. } => { + let elems = + elts.iter() + .map(|expr| self.get_expr_lifetime(expr)) + .collect::, _>>()?; + let tuple_lifetime = + self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local); + for (i, lifetime) in elems.into_iter().enumerate() { + if let Some((lifetime, _)) = lifetime { + self.lifetime_table + .set_field_lifetime( + tuple_lifetime, + format!("elem{}", i).into(), + lifetime, + true, + &mut self.scope_ctx, + ) + .unwrap(); + } + } + Some((tuple_lifetime, true)) + } + ExprKind::Call { func, args, keywords } => { + let mut lifetimes = Vec::new(); + for arg in args.iter() { + if let Some(lifetime) = self.get_expr_lifetime(arg)? { + lifetimes.push(lifetime.0); + } + } + for keyword in keywords.iter() { + if let Some(lifetime) = self.get_expr_lifetime(&keyword.node.value)? { + lifetimes.push(lifetime.0); + } + } + match &func.node { + ExprKind::Name { id, .. } => { + for lifetime in lifetimes.into_iter() { + self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx); + } + if is_primitive { + None + } else { + let id = self + .resolver + .get_identifier_def(*id) + .map_err(|e| format!("{} (at {})", e, func.location))?; + // constructors + if let TopLevelDef::Class { .. } = + &*self.top_level.definitions.read()[id.0].read() + { + Some(( + self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local), + true, + )) + } else { + Some((self.lifetime_table.get_unknown_lifetime(), false)) + } + } + } + ExprKind::Attribute { value, .. } => { + if let Some(lifetime) = self.get_expr_lifetime(value)? { + lifetimes.push(lifetime.0); + } + for lifetime in lifetimes.into_iter() { + self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx); + } + if is_primitive { + None + } else { + Some((self.lifetime_table.get_unknown_lifetime(), false)) + } + } + _ => unimplemented!(), + } + } + ExprKind::BinOp { left, right, .. } => { + let mut lifetimes = Vec::new(); + if let Some(l) = self.get_expr_lifetime(left)? { + lifetimes.push(l.0); + } + if let Some(l) = self.get_expr_lifetime(right)? { + lifetimes.push(l.0); + } + for lifetime in lifetimes.into_iter() { + self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx); + } + if is_primitive { + None + } else { + Some((self.lifetime_table.get_unknown_lifetime(), false)) + } + } + ExprKind::BoolOp { values, .. } => { + for v in values { + self.get_expr_lifetime(v)?; + } + None + } + ExprKind::UnaryOp { operand, .. } => { + if let Some(l) = self.get_expr_lifetime(operand)? { + self.lifetime_table.set_function_params(l.0, &mut self.scope_ctx); + } + if is_primitive { + None + } else { + Some((self.lifetime_table.get_unknown_lifetime(), false)) + } + } + ExprKind::Compare { left, comparators, .. } => { + let mut lifetimes = Vec::new(); + if let Some(l) = self.get_expr_lifetime(left)? { + lifetimes.push(l.0); + } + for c in comparators { + if let Some(l) = self.get_expr_lifetime(c)? { + lifetimes.push(l.0); + } + } + for lifetime in lifetimes.into_iter() { + self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx); + } + // compare should give bool output, which does not have lifetime + None + } + // TODO: listcomp, ifexpr + _ => unimplemented!(), + }) + + } +} diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index 4cac1bf0e..1b349230f 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -4,3 +4,4 @@ pub mod type_error; pub mod type_inferencer; pub mod typedef; mod unification_table; +pub mod escape_analysis; diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 101057c5e..adb0a11d1 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use itertools::izip; #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] -pub struct UnificationKey(usize); +pub struct UnificationKey(pub(crate) usize); #[derive(Clone)] pub struct UnificationTable { @@ -44,6 +44,12 @@ impl UnificationTable { UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } } + fn log_action(&mut self, action: Action) { + if !self.log.is_empty() { + self.log.push(action); + } + } + pub fn new_key(&mut self, v: V) -> UnificationKey { let index = self.parents.len(); self.parents.push(index); @@ -61,10 +67,10 @@ impl UnificationTable { if self.ranks[a] < self.ranks[b] { std::mem::swap(&mut a, &mut b); } - self.log.push(Action::Parent { key: b, original_parent: self.parents[b] }); + self.log_action(Action::Parent { key: b, original_parent: self.parents[b] }); self.parents[b] = a; if self.ranks[a] == self.ranks[b] { - self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] }); + self.log_action(Action::Rank { key: a, original_rank: self.ranks[a] }); self.ranks[a] += 1; } } @@ -88,7 +94,7 @@ impl UnificationTable { pub fn set_value(&mut self, a: UnificationKey, v: V) { let index = self.find(a); let original_value = self.values[index].replace(v); - self.log.push(Action::Value { key: index, original_value }); + self.log_action(Action::Value { key: index, original_value }); } pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { @@ -106,7 +112,7 @@ impl UnificationTable { // a = parent.parent let a = self.parents[parent]; // root.parent = parent.parent - self.log.push(Action::Parent { key: root, original_parent: self.parents[root] }); + self.log_action(Action::Parent { key: root, original_parent: self.parents[root] }); self.parents[root] = a; root = parent; // parent = root.parent