From cc769a70066963da8204cda7c272c16798ef48ac Mon Sep 17 00:00:00 2001 From: pca006132 Date: Fri, 25 Feb 2022 14:47:19 +0800 Subject: [PATCH] nac3core: reset unification table state before printing errors Fixes nondeterministic error messages due to nondeterministic unification order. As all unification operations will be restored, the error messages should not be affected by the unification order before the failure operation. --- nac3core/src/typecheck/typedef/mod.rs | 43 +++++++++++++- nac3core/src/typecheck/unification_table.rs | 66 +++++++++++++++++++-- 2 files changed, 103 insertions(+), 6 deletions(-) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 579d80af..0af37127 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -161,6 +161,7 @@ pub struct Unifier { pub(super) calls: Vec>, var_id: u32, unify_cache: HashSet<(Type, Type)>, + snapshot: Option<(usize, u32)> } impl Default for Unifier { @@ -178,6 +179,7 @@ impl Unifier { calls: Vec::new(), unify_cache: HashSet::new(), top_level: None, + snapshot: None, } } @@ -198,6 +200,7 @@ impl Unifier { calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), top_level: None, unify_cache: HashSet::new(), + snapshot: None, } } @@ -383,6 +386,19 @@ impl Unifier { } } + fn restore_snapshot(&mut self) { + if let Some(snapshot) = self.snapshot.take() { + self.unification_table.restore_snapshot(snapshot); + } + } + + fn discard_snapshot(&mut self, snapshot: (usize, u32)) { + if self.snapshot == Some(snapshot) { + self.unification_table.discard_snapshot(snapshot); + self.snapshot = None; + } + } + pub fn unify_call( &mut self, call: &Call, @@ -390,6 +406,11 @@ impl Unifier { signature: &FunSignature, required: &[StrRef], ) -> Result<(), TypeError> { + let snapshot = self.unification_table.get_snapshot(); + if self.snapshot.is_none() { + self.snapshot = Some(snapshot); + } + let Call { posargs, kwargs, ret, fun, loc } = call; let instantiated = self.instantiate_fun(b, &*signature); let r = self.get_ty(instantiated); @@ -414,6 +435,7 @@ impl Unifier { required.pop(); let (name, expected) = all_names.pop().unwrap(); self.unify_impl(expected, *t, false).map_err(|_| { + self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; } @@ -424,34 +446,51 @@ impl Unifier { let i = all_names .iter() .position(|v| &v.0 == k) - .ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; + .ok_or_else(|| { + self.restore_snapshot(); + TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) + })?; let (name, expected) = all_names.remove(i); self.unify_impl(expected, *t, false).map_err(|_| { + self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; } if !required.is_empty() { + self.restore_snapshot(); return Err(TypeError::new( TypeErrorKind::MissingArgs(required.iter().join(", ")), *loc, )); } self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { + self.restore_snapshot(); if err.loc.is_none() { err.loc = *loc; } err })?; *fun.borrow_mut() = Some(instantiated); + + self.discard_snapshot(snapshot); Ok(()) } pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> { + let snapshot = self.unification_table.get_snapshot(); + if self.snapshot.is_none() { + self.snapshot = Some(snapshot); + } self.unify_cache.clear(); if self.unification_table.unioned(a, b) { Ok(()) } else { - self.unify_impl(a, b, false) + let result = self.unify_impl(a, b, false); + if result.is_err() { + self.restore_snapshot(); + } + self.discard_snapshot(snapshot); + result } } diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs index 27df8dbe..70864154 100644 --- a/nac3core/src/typecheck/unification_table.rs +++ b/nac3core/src/typecheck/unification_table.rs @@ -10,6 +10,27 @@ pub struct UnificationTable { parents: Vec, ranks: Vec, values: Vec>, + log: Vec>, + generation: u32, +} + +#[derive(Clone, Debug)] +enum Action { + Parent { + key: usize, + original_parent: usize, + }, + Value { + key: usize, + original_value: Option, + }, + Rank { + key: usize, + original_rank: u32, + }, + Marker { + generation: u32, + } } impl Default for UnificationTable { @@ -20,7 +41,7 @@ impl Default for UnificationTable { impl UnificationTable { pub fn new() -> UnificationTable { - UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() } + UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } } pub fn new_key(&mut self, v: V) -> UnificationKey { @@ -42,6 +63,7 @@ impl UnificationTable { } self.parents[b] = a; if self.ranks[a] == self.ranks[b] { + self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] }); self.ranks[a] += 1; } } @@ -64,7 +86,8 @@ impl UnificationTable { pub fn set_value(&mut self, a: UnificationKey, v: V) { let index = self.find(a); - self.values[index] = Some(v); + let original_value = self.values[index].replace(v); + self.log.push(Action::Value { key: index, original_value }); } pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { @@ -82,6 +105,7 @@ impl UnificationTable { // a = parent.parent let a = self.parents[parent]; // root.parent = parent.parent + self.log.push(Action::Parent { key: root, original_parent: a }); self.parents[root] = a; root = parent; // parent = root.parent @@ -89,6 +113,40 @@ impl UnificationTable { } parent } + + pub fn get_snapshot(&mut self) -> (usize, u32) { + let generation = self.generation; + self.log.push(Action::Marker { generation }); + self.generation += 1; + (self.log.len(), generation) + } + + pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { + let (log_len, generation) = snapshot; + assert!(self.log.len() >= log_len, "snapshot restoration error"); + assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); + for action in self.log.drain(log_len - 1..).rev() { + match action { + Action::Parent { key, original_parent } => { + self.parents[key] = original_parent; + } + Action::Value { key, original_value } => { + self.values[key] = original_value; + } + Action::Rank { key, original_rank } => { + self.ranks[key] = original_rank; + } + Action::Marker { .. } => {} + } + } + } + + pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { + let (log_len, generation) = snapshot; + assert!(self.log.len() >= log_len, "snapshot discard error"); + assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if generation == gen), "snapshot discard error"); + self.log.truncate(log_len - 1); + } } impl UnificationTable> @@ -100,11 +158,11 @@ where .enumerate() .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .collect(); - UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values } + UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } } pub fn from_send(table: &UnificationTable) -> UnificationTable> { let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); - UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values } + UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } } }