forked from M-Labs/nac3
1
0
Fork 0

nac3core/typecheck: start implementing escape analysis

This commit is contained in:
pca006132 2022-04-04 22:42:22 +08:00
parent 4f66bdeda9
commit 10c4544553
4 changed files with 605 additions and 5 deletions

View File

@ -0,0 +1,276 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use crate::typecheck::unification_table::{UnificationKey, UnificationTable};
use itertools::Itertools;
use nac3parser::ast::StrRef;
// change this to enum, only local needs unification key
pub type Lifetime = UnificationKey;
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash)]
pub enum LifetimeKind {
// can be assigned to fields of anything
// can be returned
// lifetime of static values
Global,
// can only be assigned to fields of objects with local lifetime
// can be returned
// lifetime of parameters
NonLocal,
// can only be assigned to fields of objects with local lifetime
// cannot be returned
// lifetime of constructor return values
Local,
// can only be assigned to fields of objects with local lifetime
// cannot be returned
// lifetime of function return values
Unknown,
}
impl std::ops::BitAnd for LifetimeKind {
type Output = Self;
fn bitand(self, other: Self) -> Self {
use LifetimeKind::*;
match (self, other) {
(x, y) if x == y => x,
(Global, NonLocal) | (NonLocal, Global) => NonLocal,
_ => Unknown,
}
}
}
impl std::cmp::PartialOrd for LifetimeKind {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use LifetimeKind::*;
match (*self, *other) {
(x, y) if x == y => Some(std::cmp::Ordering::Equal),
(Local, _) | (_, Global) => Some(std::cmp::Ordering::Less),
(_, Local) | (Global, _) => Some(std::cmp::Ordering::Greater),
_ => None,
}
}
}
pub struct BlockLifetimeContext {
mapping: Vec<(Option<Lifetime>, Lifetime)>,
}
impl BlockLifetimeContext {
pub fn new() -> Self {
BlockLifetimeContext { mapping: Vec::new() }
}
pub fn add_fresh(&mut self, lifetime: Lifetime) {
self.mapping.push((None, lifetime));
}
}
struct LifetimeEntry {
kind: LifetimeKind,
fields: RefCell<HashMap<StrRef, Lifetime>>,
}
pub struct LifetimeTable {
table: UnificationTable<Rc<LifetimeEntry>>,
cache: HashSet<(Lifetime, Lifetime)>,
}
impl LifetimeTable {
pub fn new() -> Self {
let mut zelf = Self { table: UnificationTable::new(), cache: Default::default() };
zelf.table.new_key(Rc::new(LifetimeEntry {
kind: LifetimeKind::Unknown,
fields: Default::default(),
}));
zelf
}
pub fn add_lifetime(&mut self, kind: LifetimeKind) -> Lifetime {
self.table.new_key(Rc::new(LifetimeEntry { kind, fields: Default::default() }))
}
pub fn unify(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) {
self.cache.clear();
self.unify_impl(a, b, ctx);
}
fn get_scoped<const N: usize>(
&mut self,
mut lifetimes: [Lifetime; N],
ctx: &mut BlockLifetimeContext,
) -> [Lifetime; N] {
for l in lifetimes.iter_mut() {
let mut result = None;
for (k, v) in ctx.mapping.iter() {
if self.table.unioned(*v, *l) || k.map_or(false, |k| self.table.unioned(k, *l)) {
// already fresh
result = Some(*v);
break;
}
}
if let Some(result) = result {
*l = result;
} else {
let lifetime = self.table.probe_value(*l).clone();
*l = if lifetime.kind == LifetimeKind::Unknown {
UnificationKey(0)
} else {
let k = self.table.new_key(lifetime);
ctx.mapping.push((Some(*l), k));
k
}
}
}
lifetimes
}
fn unify_impl(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) {
use LifetimeKind::*;
let [a, b] = self.get_scoped([a, b], ctx);
let a = self.table.get_representative(a);
let b = self.table.get_representative(b);
if a == b || self.cache.contains(&(a, b)) || self.cache.contains(&(b, a)) {
return;
}
self.cache.insert((a, b));
let v_a = self.table.probe_value(a).clone();
let v_b = self.table.probe_value(b).clone();
let result_kind = v_a.kind & v_b.kind;
let fields = if result_kind == Local {
// we only need to track fields lifetime for objects with local lifetime
let fields = v_a.fields.clone();
{
let mut fields_ref = fields.borrow_mut();
for (k, v) in v_b.fields.borrow().iter() {
if let Some(old) = fields_ref.insert(k.clone(), *v) {
self.unify_impl(old, *v, ctx);
}
}
}
fields
} else {
Default::default()
};
self.table.unify(a, b);
self.table.set_value(a, Rc::new(LifetimeEntry { kind: result_kind, fields }));
}
pub fn get_field_lifetime(
&mut self,
lifetime: Lifetime,
field: StrRef,
ctx: &mut BlockLifetimeContext,
) -> Lifetime {
use LifetimeKind::*;
let [lifetime] = self.get_scoped([lifetime], ctx);
if let LifetimeEntry { kind: Local, fields } = &*self.table.probe_value(lifetime).clone() {
if let Some(lifetime) = fields.borrow().get(&field) {
*lifetime
} else {
// unknown lifetime
// we can reuse this lifetime because it will never be unified to something else
UnificationKey(0)
}
} else {
lifetime
}
}
pub fn set_field_lifetime(
&mut self,
obj: Lifetime,
field: StrRef,
lifetime: Lifetime,
is_strong_update: bool,
ctx: &mut BlockLifetimeContext,
) -> Result<(), String> {
let [obj, lifetime] = self.get_scoped([obj, lifetime], ctx);
let obj_lifetime = self.table.probe_value(obj).clone();
let field_lifetime = self.table.probe_value(lifetime).clone();
if !(obj_lifetime.kind <= field_lifetime.kind) {
return Err("lifetime error".to_string());
}
let mut fields = obj_lifetime.fields.borrow_mut();
if is_strong_update {
fields.insert(field, lifetime);
} else {
if let Some(old) = fields.insert(field, lifetime) {
self.unify(old, lifetime, ctx);
}
}
Ok(())
}
pub fn get_lifetime_kind(
&mut self,
lifetime: Lifetime,
ctx: &mut BlockLifetimeContext,
) -> LifetimeKind {
let [lifetime] = self.get_scoped([lifetime], ctx);
self.table.probe_value(lifetime).kind
}
pub fn set_function_params(&mut self, lifetime: Lifetime, ctx: &mut BlockLifetimeContext) {
use LifetimeKind::*;
// unify each field with global
let [lifetime] = self.get_scoped([lifetime], ctx);
let lifetime = self.table.probe_value(lifetime).clone();
let mut worklist = lifetime.fields.borrow().values().copied().collect_vec();
while let Some(item) = worklist.pop() {
let [item] = self.get_scoped([item], ctx);
let lifetime = self.table.probe_value(item).clone();
if lifetime.kind == Unknown || lifetime.kind == Global {
continue;
}
let fields = lifetime.fields.borrow().clone();
for (_, v) in fields.iter() {
worklist.push(*v);
}
self.table.set_value(
item,
Rc::new(LifetimeEntry {
kind: lifetime.kind & Global,
fields: RefCell::new(fields),
}),
);
}
}
pub fn get_unknown_lifetime(&self) -> Lifetime {
UnificationKey(0)
}
pub fn equiv(&mut self, a: Lifetime, b: Lifetime, ctx: &mut BlockLifetimeContext) -> bool {
use LifetimeKind::Local;
let [a, b] = self.get_scoped([a, b], ctx);
if self.table.unioned(a, b) {
return true;
}
let lifetime_a = self.table.probe_value(a).clone();
let lifetime_b = self.table.probe_value(b).clone();
if lifetime_a.kind == Local && lifetime_b.kind == Local {
let fields_a = lifetime_a.fields.borrow();
let fields_b = lifetime_b.fields.borrow();
for (k, v) in fields_a.iter() {
if fields_b.get(k).map(|v1| self.equiv(*v, *v1, ctx)) != Some(true) {
return false;
}
}
// they are just equivalent
// this can avoid infinite recursion
self.table.unify(a, b);
true
} else {
lifetime_a.kind == lifetime_b.kind
}
}
}

View File

@ -0,0 +1,317 @@
use std::{collections::HashMap, sync::Arc};
use nac3parser::ast::{Constant, Expr, ExprKind, StrRef};
use crate::{
symbol_resolver::SymbolResolver,
toplevel::{TopLevelContext, TopLevelDef},
};
use self::lifetime::{BlockLifetimeContext, Lifetime, LifetimeTable};
use super::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier},
};
pub mod lifetime;
struct LifetimeContext<'a> {
variable_mapping: HashMap<StrRef, (Lifetime, bool)>,
scope_ctx: BlockLifetimeContext,
lifetime_table: LifetimeTable,
primitive_store: PrimitiveStore,
unifier: &'a mut Unifier,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
top_level: &'a TopLevelContext,
}
impl<'a> LifetimeContext<'a> {
pub fn new(
unifier: &'a mut Unifier,
primitive_store: PrimitiveStore,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
top_level: &'a TopLevelContext,
) -> LifetimeContext<'a> {
LifetimeContext {
variable_mapping: HashMap::new(),
scope_ctx: BlockLifetimeContext::new(),
lifetime_table: LifetimeTable::new(),
primitive_store,
unifier,
resolver,
top_level,
}
}
fn get_expr_lifetime(
&mut self,
expr: &Expr<Option<Type>>,
) -> Result<Option<(Lifetime, bool)>, String> {
let ty = expr.custom.unwrap();
let is_primitive = self.unifier.unioned(ty, self.primitive_store.int32)
|| self.unifier.unioned(ty, self.primitive_store.int64)
|| self.unifier.unioned(ty, self.primitive_store.uint32)
|| self.unifier.unioned(ty, self.primitive_store.uint64)
|| self.unifier.unioned(ty, self.primitive_store.float)
|| self.unifier.unioned(ty, self.primitive_store.bool)
|| self.unifier.unioned(ty, self.primitive_store.none)
|| self.unifier.unioned(ty, self.primitive_store.range);
Ok(match &expr.node {
ExprKind::Name { id, .. } => {
if let Some(lifetime) = self.variable_mapping.get(id) {
Some(*lifetime)
} else {
if is_primitive {
None
} else {
let lifetime =
self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Global);
self.variable_mapping.insert(id.clone(), (lifetime, false));
Some((lifetime, false))
}
}
}
ExprKind::Attribute { value, attr, .. } => {
if is_primitive {
self.get_expr_lifetime(value)?;
None
} else {
self.get_expr_lifetime(value)?.map(|lifetime| {
(
self.lifetime_table.get_field_lifetime(
lifetime.0,
*attr,
&mut self.scope_ctx,
),
false, // not sure if it is strong update for now...
)
})
}
}
ExprKind::Constant { .. } => {
if is_primitive {
None
} else {
Some((self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Global), false))
}
}
ExprKind::List { elts, .. } => {
let elems =
elts.iter()
.map(|expr| self.get_expr_lifetime(expr))
.collect::<Result<Vec<_>, _>>()?;
let elem = elems.into_iter().reduce(|prev, next| {
if prev.is_some() {
self.lifetime_table.unify(
prev.unwrap().0,
next.unwrap().0,
&mut self.scope_ctx,
);
}
prev
});
let list_lifetime = self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local);
if let Some(Some(elem)) = elem {
self.lifetime_table
.set_field_lifetime(
list_lifetime,
"elem".into(),
elem.0,
true,
&mut self.scope_ctx,
)
.unwrap();
}
Some((list_lifetime, true))
}
ExprKind::Subscript { value, slice, .. } => {
// value must be a list, so lifetime cannot be None
let (value_lifetime, _) = self.get_expr_lifetime(value)?.unwrap();
match &slice.node {
ExprKind::Slice { lower, upper, step } => {
for expr in [lower, upper, step].iter().filter_map(|x| x.as_ref()) {
// account for side effects when computing the slice
self.get_expr_lifetime(expr)?;
}
Some((
self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local),
true,
))
}
ExprKind::Constant { value: Constant::Int(v), .. } => {
if is_primitive {
None
} else if let TypeEnum::TList { .. } =
&*self.unifier.get_ty(value.custom.unwrap())
{
Some((
self.lifetime_table.get_field_lifetime(
value_lifetime,
"elem".into(),
&mut self.scope_ctx,
),
false,
))
} else {
// tuple
Some((
self.lifetime_table.get_field_lifetime(
value_lifetime,
format!("elem{}", v).into(),
&mut self.scope_ctx,
),
false,
))
}
}
_ => {
// account for side effects when computing the index
self.get_expr_lifetime(slice)?;
if is_primitive {
None
} else {
Some((
self.lifetime_table.get_field_lifetime(
value_lifetime,
"elem".into(),
&mut self.scope_ctx,
),
false,
))
}
}
}
}
ExprKind::Tuple { elts, .. } => {
let elems =
elts.iter()
.map(|expr| self.get_expr_lifetime(expr))
.collect::<Result<Vec<_>, _>>()?;
let tuple_lifetime =
self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local);
for (i, lifetime) in elems.into_iter().enumerate() {
if let Some((lifetime, _)) = lifetime {
self.lifetime_table
.set_field_lifetime(
tuple_lifetime,
format!("elem{}", i).into(),
lifetime,
true,
&mut self.scope_ctx,
)
.unwrap();
}
}
Some((tuple_lifetime, true))
}
ExprKind::Call { func, args, keywords } => {
let mut lifetimes = Vec::new();
for arg in args.iter() {
if let Some(lifetime) = self.get_expr_lifetime(arg)? {
lifetimes.push(lifetime.0);
}
}
for keyword in keywords.iter() {
if let Some(lifetime) = self.get_expr_lifetime(&keyword.node.value)? {
lifetimes.push(lifetime.0);
}
}
match &func.node {
ExprKind::Name { id, .. } => {
for lifetime in lifetimes.into_iter() {
self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx);
}
if is_primitive {
None
} else {
let id = self
.resolver
.get_identifier_def(*id)
.map_err(|e| format!("{} (at {})", e, func.location))?;
// constructors
if let TopLevelDef::Class { .. } =
&*self.top_level.definitions.read()[id.0].read()
{
Some((
self.lifetime_table.add_lifetime(lifetime::LifetimeKind::Local),
true,
))
} else {
Some((self.lifetime_table.get_unknown_lifetime(), false))
}
}
}
ExprKind::Attribute { value, .. } => {
if let Some(lifetime) = self.get_expr_lifetime(value)? {
lifetimes.push(lifetime.0);
}
for lifetime in lifetimes.into_iter() {
self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx);
}
if is_primitive {
None
} else {
Some((self.lifetime_table.get_unknown_lifetime(), false))
}
}
_ => unimplemented!(),
}
}
ExprKind::BinOp { left, right, .. } => {
let mut lifetimes = Vec::new();
if let Some(l) = self.get_expr_lifetime(left)? {
lifetimes.push(l.0);
}
if let Some(l) = self.get_expr_lifetime(right)? {
lifetimes.push(l.0);
}
for lifetime in lifetimes.into_iter() {
self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx);
}
if is_primitive {
None
} else {
Some((self.lifetime_table.get_unknown_lifetime(), false))
}
}
ExprKind::BoolOp { values, .. } => {
for v in values {
self.get_expr_lifetime(v)?;
}
None
}
ExprKind::UnaryOp { operand, .. } => {
if let Some(l) = self.get_expr_lifetime(operand)? {
self.lifetime_table.set_function_params(l.0, &mut self.scope_ctx);
}
if is_primitive {
None
} else {
Some((self.lifetime_table.get_unknown_lifetime(), false))
}
}
ExprKind::Compare { left, comparators, .. } => {
let mut lifetimes = Vec::new();
if let Some(l) = self.get_expr_lifetime(left)? {
lifetimes.push(l.0);
}
for c in comparators {
if let Some(l) = self.get_expr_lifetime(c)? {
lifetimes.push(l.0);
}
}
for lifetime in lifetimes.into_iter() {
self.lifetime_table.set_function_params(lifetime, &mut self.scope_ctx);
}
// compare should give bool output, which does not have lifetime
None
}
// TODO: listcomp, ifexpr
_ => unimplemented!(),
})
}
}

View File

@ -4,3 +4,4 @@ pub mod type_error;
pub mod type_inferencer; pub mod type_inferencer;
pub mod typedef; pub mod typedef;
mod unification_table; mod unification_table;
pub mod escape_analysis;

View File

@ -3,7 +3,7 @@ use std::rc::Rc;
use itertools::izip; use itertools::izip;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(usize); pub struct UnificationKey(pub(crate) usize);
#[derive(Clone)] #[derive(Clone)]
pub struct UnificationTable<V> { pub struct UnificationTable<V> {
@ -44,6 +44,12 @@ impl<V> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
} }
fn log_action(&mut self, action: Action<V>) {
if !self.log.is_empty() {
self.log.push(action);
}
}
pub fn new_key(&mut self, v: V) -> UnificationKey { pub fn new_key(&mut self, v: V) -> UnificationKey {
let index = self.parents.len(); let index = self.parents.len();
self.parents.push(index); self.parents.push(index);
@ -61,10 +67,10 @@ impl<V> UnificationTable<V> {
if self.ranks[a] < self.ranks[b] { if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b); std::mem::swap(&mut a, &mut b);
} }
self.log.push(Action::Parent { key: b, original_parent: self.parents[b] }); self.log_action(Action::Parent { key: b, original_parent: self.parents[b] });
self.parents[b] = a; self.parents[b] = a;
if self.ranks[a] == self.ranks[b] { if self.ranks[a] == self.ranks[b] {
self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] }); self.log_action(Action::Rank { key: a, original_rank: self.ranks[a] });
self.ranks[a] += 1; self.ranks[a] += 1;
} }
} }
@ -88,7 +94,7 @@ impl<V> UnificationTable<V> {
pub fn set_value(&mut self, a: UnificationKey, v: V) { pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a); let index = self.find(a);
let original_value = self.values[index].replace(v); let original_value = self.values[index].replace(v);
self.log.push(Action::Value { key: index, original_value }); self.log_action(Action::Value { key: index, original_value });
} }
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
@ -106,7 +112,7 @@ impl<V> UnificationTable<V> {
// a = parent.parent // a = parent.parent
let a = self.parents[parent]; let a = self.parents[parent];
// root.parent = parent.parent // root.parent = parent.parent
self.log.push(Action::Parent { key: root, original_parent: self.parents[root] }); self.log_action(Action::Parent { key: root, original_parent: self.parents[root] });
self.parents[root] = a; self.parents[root] = a;
root = parent; root = parent;
// parent = root.parent // parent = root.parent