diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 19f836d8..619c23a7 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -1,5 +1,7 @@ use std::rc::Rc; +use itertools::izip; + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub struct UnificationKey(usize); @@ -7,7 +9,7 @@ pub struct UnificationKey(usize); pub struct UnificationTable { parents: Vec, ranks: Vec, - values: Vec, + values: Vec>, } impl UnificationTable { @@ -19,7 +21,7 @@ impl UnificationTable { let index = self.parents.len(); self.parents.push(index); self.ranks.push(0); - self.values.push(v); + self.values.push(Some(v)); UnificationKey(index) } @@ -40,12 +42,12 @@ impl UnificationTable { pub fn probe_value(&mut self, a: UnificationKey) -> &V { let index = self.find(a); - &self.values[index] + self.values[index].as_ref().unwrap() } pub fn set_value(&mut self, a: UnificationKey, v: V) { let index = self.find(a); - self.values[index] = v; + self.values[index] = Some(v); } pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { @@ -77,12 +79,15 @@ where V: Clone, { pub fn get_send(&self) -> UnificationTable { - let values = self.values.iter().map(|v| v.as_ref().clone()).collect(); + let values = izip!(self.values.iter(), self.parents.iter()) + .enumerate() + .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) + .collect(); UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values } } pub fn from_send(table: &UnificationTable) -> UnificationTable> { - let values = table.values.iter().cloned().map(Rc::new).collect(); + let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values } } }