forked from M-Labs/nac3
nac3core: reset unification table state before printing errors
Fixes nondeterministic error messages due to nondeterministic unification order. As all unification operations will be restored, the error messages should not be affected by the unification order before the failure operation.
This commit is contained in:
parent
5cd4fe6507
commit
cc769a7006
|
@ -161,6 +161,7 @@ pub struct Unifier {
|
||||||
pub(super) calls: Vec<Rc<Call>>,
|
pub(super) calls: Vec<Rc<Call>>,
|
||||||
var_id: u32,
|
var_id: u32,
|
||||||
unify_cache: HashSet<(Type, Type)>,
|
unify_cache: HashSet<(Type, Type)>,
|
||||||
|
snapshot: Option<(usize, u32)>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Unifier {
|
impl Default for Unifier {
|
||||||
|
@ -178,6 +179,7 @@ impl Unifier {
|
||||||
calls: Vec::new(),
|
calls: Vec::new(),
|
||||||
unify_cache: HashSet::new(),
|
unify_cache: HashSet::new(),
|
||||||
top_level: None,
|
top_level: None,
|
||||||
|
snapshot: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,6 +200,7 @@ impl Unifier {
|
||||||
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||||
top_level: None,
|
top_level: None,
|
||||||
unify_cache: HashSet::new(),
|
unify_cache: HashSet::new(),
|
||||||
|
snapshot: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,6 +386,19 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn restore_snapshot(&mut self) {
|
||||||
|
if let Some(snapshot) = self.snapshot.take() {
|
||||||
|
self.unification_table.restore_snapshot(snapshot);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
|
||||||
|
if self.snapshot == Some(snapshot) {
|
||||||
|
self.unification_table.discard_snapshot(snapshot);
|
||||||
|
self.snapshot = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn unify_call(
|
pub fn unify_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
call: &Call,
|
call: &Call,
|
||||||
|
@ -390,6 +406,11 @@ impl Unifier {
|
||||||
signature: &FunSignature,
|
signature: &FunSignature,
|
||||||
required: &[StrRef],
|
required: &[StrRef],
|
||||||
) -> Result<(), TypeError> {
|
) -> Result<(), TypeError> {
|
||||||
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
|
if self.snapshot.is_none() {
|
||||||
|
self.snapshot = Some(snapshot);
|
||||||
|
}
|
||||||
|
|
||||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||||
let instantiated = self.instantiate_fun(b, &*signature);
|
let instantiated = self.instantiate_fun(b, &*signature);
|
||||||
let r = self.get_ty(instantiated);
|
let r = self.get_ty(instantiated);
|
||||||
|
@ -414,6 +435,7 @@ impl Unifier {
|
||||||
required.pop();
|
required.pop();
|
||||||
let (name, expected) = all_names.pop().unwrap();
|
let (name, expected) = all_names.pop().unwrap();
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||||
|
self.restore_snapshot();
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
@ -424,34 +446,51 @@ impl Unifier {
|
||||||
let i = all_names
|
let i = all_names
|
||||||
.iter()
|
.iter()
|
||||||
.position(|v| &v.0 == k)
|
.position(|v| &v.0 == k)
|
||||||
.ok_or_else(|| TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc))?;
|
.ok_or_else(|| {
|
||||||
|
self.restore_snapshot();
|
||||||
|
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
||||||
|
})?;
|
||||||
let (name, expected) = all_names.remove(i);
|
let (name, expected) = all_names.remove(i);
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||||
|
self.restore_snapshot();
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
if !required.is_empty() {
|
if !required.is_empty() {
|
||||||
|
self.restore_snapshot();
|
||||||
return Err(TypeError::new(
|
return Err(TypeError::new(
|
||||||
TypeErrorKind::MissingArgs(required.iter().join(", ")),
|
TypeErrorKind::MissingArgs(required.iter().join(", ")),
|
||||||
*loc,
|
*loc,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||||
|
self.restore_snapshot();
|
||||||
if err.loc.is_none() {
|
if err.loc.is_none() {
|
||||||
err.loc = *loc;
|
err.loc = *loc;
|
||||||
}
|
}
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
*fun.borrow_mut() = Some(instantiated);
|
*fun.borrow_mut() = Some(instantiated);
|
||||||
|
|
||||||
|
self.discard_snapshot(snapshot);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
|
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), TypeError> {
|
||||||
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
|
if self.snapshot.is_none() {
|
||||||
|
self.snapshot = Some(snapshot);
|
||||||
|
}
|
||||||
self.unify_cache.clear();
|
self.unify_cache.clear();
|
||||||
if self.unification_table.unioned(a, b) {
|
if self.unification_table.unioned(a, b) {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
self.unify_impl(a, b, false)
|
let result = self.unify_impl(a, b, false);
|
||||||
|
if result.is_err() {
|
||||||
|
self.restore_snapshot();
|
||||||
|
}
|
||||||
|
self.discard_snapshot(snapshot);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,27 @@ pub struct UnificationTable<V> {
|
||||||
parents: Vec<usize>,
|
parents: Vec<usize>,
|
||||||
ranks: Vec<u32>,
|
ranks: Vec<u32>,
|
||||||
values: Vec<Option<V>>,
|
values: Vec<Option<V>>,
|
||||||
|
log: Vec<Action<V>>,
|
||||||
|
generation: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum Action<V> {
|
||||||
|
Parent {
|
||||||
|
key: usize,
|
||||||
|
original_parent: usize,
|
||||||
|
},
|
||||||
|
Value {
|
||||||
|
key: usize,
|
||||||
|
original_value: Option<V>,
|
||||||
|
},
|
||||||
|
Rank {
|
||||||
|
key: usize,
|
||||||
|
original_rank: u32,
|
||||||
|
},
|
||||||
|
Marker {
|
||||||
|
generation: u32,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V> Default for UnificationTable<V> {
|
impl<V> Default for UnificationTable<V> {
|
||||||
|
@ -20,7 +41,7 @@ impl<V> Default for UnificationTable<V> {
|
||||||
|
|
||||||
impl<V> UnificationTable<V> {
|
impl<V> UnificationTable<V> {
|
||||||
pub fn new() -> UnificationTable<V> {
|
pub fn new() -> UnificationTable<V> {
|
||||||
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new() }
|
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_key(&mut self, v: V) -> UnificationKey {
|
pub fn new_key(&mut self, v: V) -> UnificationKey {
|
||||||
|
@ -42,6 +63,7 @@ impl<V> UnificationTable<V> {
|
||||||
}
|
}
|
||||||
self.parents[b] = a;
|
self.parents[b] = a;
|
||||||
if self.ranks[a] == self.ranks[b] {
|
if self.ranks[a] == self.ranks[b] {
|
||||||
|
self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] });
|
||||||
self.ranks[a] += 1;
|
self.ranks[a] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -64,7 +86,8 @@ impl<V> UnificationTable<V> {
|
||||||
|
|
||||||
pub fn set_value(&mut self, a: UnificationKey, v: V) {
|
pub fn set_value(&mut self, a: UnificationKey, v: V) {
|
||||||
let index = self.find(a);
|
let index = self.find(a);
|
||||||
self.values[index] = Some(v);
|
let original_value = self.values[index].replace(v);
|
||||||
|
self.log.push(Action::Value { key: index, original_value });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
|
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
|
||||||
|
@ -82,6 +105,7 @@ impl<V> UnificationTable<V> {
|
||||||
// a = parent.parent
|
// a = parent.parent
|
||||||
let a = self.parents[parent];
|
let a = self.parents[parent];
|
||||||
// root.parent = parent.parent
|
// root.parent = parent.parent
|
||||||
|
self.log.push(Action::Parent { key: root, original_parent: a });
|
||||||
self.parents[root] = a;
|
self.parents[root] = a;
|
||||||
root = parent;
|
root = parent;
|
||||||
// parent = root.parent
|
// parent = root.parent
|
||||||
|
@ -89,6 +113,40 @@ impl<V> UnificationTable<V> {
|
||||||
}
|
}
|
||||||
parent
|
parent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_snapshot(&mut self) -> (usize, u32) {
|
||||||
|
let generation = self.generation;
|
||||||
|
self.log.push(Action::Marker { generation });
|
||||||
|
self.generation += 1;
|
||||||
|
(self.log.len(), generation)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
|
||||||
|
let (log_len, generation) = snapshot;
|
||||||
|
assert!(self.log.len() >= log_len, "snapshot restoration error");
|
||||||
|
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
|
||||||
|
for action in self.log.drain(log_len - 1..).rev() {
|
||||||
|
match action {
|
||||||
|
Action::Parent { key, original_parent } => {
|
||||||
|
self.parents[key] = original_parent;
|
||||||
|
}
|
||||||
|
Action::Value { key, original_value } => {
|
||||||
|
self.values[key] = original_value;
|
||||||
|
}
|
||||||
|
Action::Rank { key, original_rank } => {
|
||||||
|
self.ranks[key] = original_rank;
|
||||||
|
}
|
||||||
|
Action::Marker { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
|
||||||
|
let (log_len, generation) = snapshot;
|
||||||
|
assert!(self.log.len() >= log_len, "snapshot discard error");
|
||||||
|
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if generation == gen), "snapshot discard error");
|
||||||
|
self.log.truncate(log_len - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V> UnificationTable<Rc<V>>
|
impl<V> UnificationTable<Rc<V>>
|
||||||
|
@ -100,11 +158,11 @@ where
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
|
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
|
||||||
.collect();
|
.collect();
|
||||||
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values }
|
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
|
||||||
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
|
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
|
||||||
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values }
|
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue