diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 579d80af..0af37127 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -161,6 +161,7 @@ pub struct Unifier { pub(super) calls: Vec>, var_id: u32, unify_cache: HashSet<(Type, Type)>, + snapshot: Option<(usize, u32)> } impl Default for Unifier { @@ -178,6 +179,7 @@ impl Unifier { calls: Vec::new(), unify_cache: HashSet::new(), top_level: None, + snapshot: None, } } @@ -198,6 +200,7 @@ impl Unifier { calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), top_level: None, unify_cache: HashSet::new(), + snapshot: None, } } @@ -383,6 +386,19 @@ impl Unifier { } } + fn restore_snapshot(&mut self) { + if let Some(snapshot) = self.snapshot.take() { + self.unification_table.restore_snapshot(snapshot); + } + } + + fn discard_snapshot(&mut self, snapshot: (usize, u32)) { + if self.snapshot == Some(snapshot) { + self.unification_table.discard_snapshot(snapshot); + self.snapshot = None; + } + } + pub fn unify_call( &mut self, call: &Call, @@ -390,6 +406,11 @@ impl Unifier { signature: &FunSignature, required: &[StrRef], ) -> Result<(), TypeError> { + let snapshot = self.unification_table.get_snapshot(); + if self.snapshot.is_none() { + self.snapshot = Some(snapshot); + } + let Call { posargs, kwargs, ret, fun, loc } = call; let instantiated = self.instantiate_fun(b, &*signature); let r = self.get_ty(instantiated); @@ -414,6 +435,7 @@ impl Unifier { required.pop(); let (name, expected) = all_names.pop().unwrap(); self.unify_impl(expected, *t, false).map_err(|_| { + self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; } @@ -424,34 +446,51 @@ impl Unifier { let i = all_names .iter() .position(|v| &v.0 == k) - .ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; + .ok_or_else(|| { + self.restore_snapshot(); + TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) + })?; let (name, expected) = all_names.remove(i); self.unify_impl(expected, *t, false).map_err(|_| { + self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; } if !required.is_empty() { + self.restore_snapshot(); return Err(TypeError::new( TypeErrorKind::MissingArgs(required.iter().join(", ")), *loc, )); } self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { + self.restore_snapshot(); if err.loc.is_none() { err.loc = *loc; } err })?; *fun.borrow_mut() = Some(instantiated); + + self.discard_snapshot(snapshot); Ok(()) } pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> { + let snapshot = self.unification_table.get_snapshot(); + if self.snapshot.is_none() { + self.snapshot = Some(snapshot); + } self.unify_cache.clear(); if self.unification_table.unioned(a, b) { Ok(()) } else { - self.unify_impl(a, b, false) + let result = self.unify_impl(a, b, false); + if result.is_err() { + self.restore_snapshot(); + } + self.discard_snapshot(snapshot); + result } } diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 27df8dbe..70864154 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -10,6 +10,27 @@ pub struct UnificationTable { parents: Vec, ranks: Vec, values: Vec>, + log: Vec>, + generation: u32, +} + +#[derive(Clone, Debug)] +enum Action { + Parent { + key: usize, + original_parent: usize, + }, + Value { + key: usize, + original_value: Option, + }, + Rank { + key: usize, + original_rank: u32, + }, + Marker { + generation: u32, + } } impl Default for UnificationTable { @@ -20,7 +41,7 @@ impl Default for UnificationTable { impl UnificationTable { pub fn new() -> UnificationTable { - UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() } + UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } } pub fn new_key(&mut self, v: V) -> UnificationKey { @@ -42,6 +63,7 @@ impl UnificationTable { } self.parents[b] = a; if self.ranks[a] == self.ranks[b] { + self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] }); self.ranks[a] += 1; } } @@ -64,7 +86,8 @@ impl UnificationTable { pub fn set_value(&mut self, a: UnificationKey, v: V) { let index = self.find(a); - self.values[index] = Some(v); + let original_value = self.values[index].replace(v); + self.log.push(Action::Value { key: index, original_value }); } pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { @@ -82,6 +105,7 @@ impl UnificationTable { // a = parent.parent let a = self.parents[parent]; // root.parent = parent.parent + self.log.push(Action::Parent { key: root, original_parent: a }); self.parents[root] = a; root = parent; // parent = root.parent @@ -89,6 +113,40 @@ impl UnificationTable { } parent } + + pub fn get_snapshot(&mut self) -> (usize, u32) { + let generation = self.generation; + self.log.push(Action::Marker { generation }); + self.generation += 1; + (self.log.len(), generation) + } + + pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { + let (log_len, generation) = snapshot; + assert!(self.log.len() >= log_len, "snapshot restoration error"); + assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); + for action in self.log.drain(log_len - 1..).rev() { + match action { + Action::Parent { key, original_parent } => { + self.parents[key] = original_parent; + } + Action::Value { key, original_value } => { + self.values[key] = original_value; + } + Action::Rank { key, original_rank } => { + self.ranks[key] = original_rank; + } + Action::Marker { .. } => {} + } + } + } + + pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { + let (log_len, generation) = snapshot; + assert!(self.log.len() >= log_len, "snapshot discard error"); + assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if generation == gen), "snapshot discard error"); + self.log.truncate(log_len - 1); + } } impl UnificationTable> @@ -100,11 +158,11 @@ where .enumerate() .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .collect(); - UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values } + UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } } pub fn from_send(table: &UnificationTable) -> UnificationTable> { let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); - UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values } + UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } } }