1
0
forked from M-Labs/nac3

nac3core: reset unification table state before printing errors

Fixes nondeterministic error messages due to nondeterministic
unification order. As all unification operations will be restored, the
error messages should not be affected by the unification order before
the failure operation.
This commit is contained in:
pca006132 2022-02-25 14:47:19 +08:00
parent 5cd4fe6507
commit cc769a7006
2 changed files with 103 additions and 6 deletions

View File

@ -161,6 +161,7 @@ pub struct Unifier {
pub(super) calls: Vec<Rc<Call>>, pub(super) calls: Vec<Rc<Call>>,
var_id: u32, var_id: u32,
unify_cache: HashSet<(Type, Type)>, unify_cache: HashSet<(Type, Type)>,
snapshot: Option<(usize, u32)>
} }
impl Default for Unifier { impl Default for Unifier {
@ -178,6 +179,7 @@ impl Unifier {
calls: Vec::new(), calls: Vec::new(),
unify_cache: HashSet::new(), unify_cache: HashSet::new(),
top_level: None, top_level: None,
snapshot: None,
} }
} }
@ -198,6 +200,7 @@ impl Unifier {
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
top_level: None, top_level: None,
unify_cache: HashSet::new(), unify_cache: HashSet::new(),
snapshot: None,
} }
} }
@ -383,6 +386,19 @@ impl Unifier {
} }
} }
fn restore_snapshot(&mut self) {
if let Some(snapshot) = self.snapshot.take() {
self.unification_table.restore_snapshot(snapshot);
}
}
fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
if self.snapshot == Some(snapshot) {
self.unification_table.discard_snapshot(snapshot);
self.snapshot = None;
}
}
pub fn unify_call( pub fn unify_call(
&mut self, &mut self,
call: &Call, call: &Call,
@ -390,6 +406,11 @@ impl Unifier {
signature: &FunSignature, signature: &FunSignature,
required: &[StrRef], required: &[StrRef],
) -> Result<(), TypeError> { ) -> Result<(), TypeError> {
let snapshot = self.unification_table.get_snapshot();
if self.snapshot.is_none() {
self.snapshot = Some(snapshot);
}
let Call { posargs, kwargs, ret, fun, loc } = call; let Call { posargs, kwargs, ret, fun, loc } = call;
let instantiated = self.instantiate_fun(b, &*signature); let instantiated = self.instantiate_fun(b, &*signature);
let r = self.get_ty(instantiated); let r = self.get_ty(instantiated);
@ -414,6 +435,7 @@ impl Unifier {
required.pop(); required.pop();
let (name, expected) = all_names.pop().unwrap(); let (name, expected) = all_names.pop().unwrap();
self.unify_impl(expected, *t, false).map_err(|_| { self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?; })?;
} }
@ -424,34 +446,51 @@ impl Unifier {
let i = all_names let i = all_names
.iter() .iter()
.position(|v| &v.0 == k) .position(|v| &v.0 == k)
.ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?; .ok_or_else(|| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i); let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| { self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?; })?;
} }
if !required.is_empty() { if !required.is_empty() {
self.restore_snapshot();
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::MissingArgs(required.iter().join(", ")), TypeErrorKind::MissingArgs(required.iter().join(", ")),
*loc, *loc,
)); ));
} }
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() { if err.loc.is_none() {
err.loc = *loc; err.loc = *loc;
} }
err err
})?; })?;
*fun.borrow_mut() = Some(instantiated); *fun.borrow_mut() = Some(instantiated);
self.discard_snapshot(snapshot);
Ok(()) Ok(())
} }
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> { pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
let snapshot = self.unification_table.get_snapshot();
if self.snapshot.is_none() {
self.snapshot = Some(snapshot);
}
self.unify_cache.clear(); self.unify_cache.clear();
if self.unification_table.unioned(a, b) { if self.unification_table.unioned(a, b) {
Ok(()) Ok(())
} else { } else {
self.unify_impl(a, b, false) let result = self.unify_impl(a, b, false);
if result.is_err() {
self.restore_snapshot();
}
self.discard_snapshot(snapshot);
result
} }
} }

View File

@ -10,6 +10,27 @@ pub struct UnificationTable<V> {
parents: Vec<usize>, parents: Vec<usize>,
ranks: Vec<u32>, ranks: Vec<u32>,
values: Vec<Option<V>>, values: Vec<Option<V>>,
log: Vec<Action<V>>,
generation: u32,
}
#[derive(Clone, Debug)]
enum Action<V> {
Parent {
key: usize,
original_parent: usize,
},
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
} }
impl<V> Default for UnificationTable<V> { impl<V> Default for UnificationTable<V> {
@ -20,7 +41,7 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> { impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> { pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() } UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
} }
pub fn new_key(&mut self, v: V) -> UnificationKey { pub fn new_key(&mut self, v: V) -> UnificationKey {
@ -42,6 +63,7 @@ impl<V> UnificationTable<V> {
} }
self.parents[b] = a; self.parents[b] = a;
if self.ranks[a] == self.ranks[b] { if self.ranks[a] == self.ranks[b] {
self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] });
self.ranks[a] += 1; self.ranks[a] += 1;
} }
} }
@ -64,7 +86,8 @@ impl<V> UnificationTable<V> {
pub fn set_value(&mut self, a: UnificationKey, v: V) { pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a); let index = self.find(a);
self.values[index] = Some(v); let original_value = self.values[index].replace(v);
self.log.push(Action::Value { key: index, original_value });
} }
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
@ -82,6 +105,7 @@ impl<V> UnificationTable<V> {
// a = parent.parent // a = parent.parent
let a = self.parents[parent]; let a = self.parents[parent];
// root.parent = parent.parent // root.parent = parent.parent
self.log.push(Action::Parent { key: root, original_parent: a });
self.parents[root] = a; self.parents[root] = a;
root = parent; root = parent;
// parent = root.parent // parent = root.parent
@ -89,6 +113,40 @@ impl<V> UnificationTable<V> {
} }
parent parent
} }
pub fn get_snapshot(&mut self) -> (usize, u32) {
let generation = self.generation;
self.log.push(Action::Marker { generation });
self.generation += 1;
(self.log.len(), generation)
}
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
for action in self.log.drain(log_len - 1..).rev() {
match action {
Action::Parent { key, original_parent } => {
self.parents[key] = original_parent;
}
Action::Value { key, original_value } => {
self.values[key] = original_value;
}
Action::Rank { key, original_rank } => {
self.ranks[key] = original_rank;
}
Action::Marker { .. } => {}
}
}
}
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if generation == gen), "snapshot discard error");
self.log.truncate(log_len - 1);
}
} }
impl<V> UnificationTable<Rc<V>> impl<V> UnificationTable<Rc<V>>
@ -100,11 +158,11 @@ where
.enumerate() .enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect(); .collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values } UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
} }
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> { pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values } UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
} }
} }