forked from M-Labs/nac3
core: Do not keep unification result for function arguments
For some reason, when unifying a function call parameter with an argument, subsequent calls to the same function will only accept the type of the substituted argument. This affect snippets like: ``` def make1() -> C[Literal[1]]: return ... def make2() -> C[Literal[2]]: return ... def consume(instance: C[Literal[1, 2]]): pass consume(make1()) consume(make2()) ``` The last statement will result in a compiler error, as the parameter of consume is replaced with C[Literal[1]]. We fix this by getting a snapshot before performing unification, and restoring the snapshot after unification succeeds.
This commit is contained in:
parent
0bbc9ce6f5
commit
f09f3c27a5
|
@ -529,10 +529,12 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
required.pop();
|
required.pop();
|
||||||
let (name, expected) = all_names.pop().unwrap();
|
let (name, expected) = all_names.pop().unwrap();
|
||||||
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||||
})?;
|
})?;
|
||||||
|
self.unification_table.restore_snapshot(snapshot);
|
||||||
}
|
}
|
||||||
for (k, t) in kwargs {
|
for (k, t) in kwargs {
|
||||||
if let Some(i) = required.iter().position(|v| v == k) {
|
if let Some(i) = required.iter().position(|v| v == k) {
|
||||||
|
@ -546,10 +548,12 @@ impl Unifier {
|
||||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
||||||
})?;
|
})?;
|
||||||
let (name, expected) = all_names.remove(i);
|
let (name, expected) = all_names.remove(i);
|
||||||
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
self.unify_impl(expected, *t, false).map_err(|_| {
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
||||||
})?;
|
})?;
|
||||||
|
self.unification_table.restore_snapshot(snapshot);
|
||||||
}
|
}
|
||||||
if !required.is_empty() {
|
if !required.is_empty() {
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
|
@ -746,18 +750,21 @@ impl Unifier {
|
||||||
|
|
||||||
(TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => {
|
(TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => {
|
||||||
assert_eq!(tys.len(), 1);
|
assert_eq!(tys.len(), 1);
|
||||||
assert_eq!(values.len(), 1);
|
|
||||||
|
|
||||||
let primitives = &self.primitive_store
|
let primitives = &self.primitive_store
|
||||||
.expect("Expected PrimitiveStore to be present");
|
.expect("Expected PrimitiveStore to be present");
|
||||||
|
|
||||||
let ty = tys[0];
|
let ty = tys[0];
|
||||||
let value= &values[0];
|
|
||||||
|
for value in values {
|
||||||
let value_ty = value.get_type(primitives, self);
|
let value_ty = value.get_type(primitives, self);
|
||||||
|
|
||||||
// If the types don't match, try to implicitly promote integers
|
if self.unioned(ty, value_ty) {
|
||||||
if !self.unioned(ty, value_ty) {
|
self.set_a_to_b(a, b);
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// The types don't match, try to implicitly promote integers
|
||||||
let num_val = match *value {
|
let num_val = match *value {
|
||||||
SymbolValue::I32(v) => v as i128,
|
SymbolValue::I32(v) => v as i128,
|
||||||
SymbolValue::I64(v) => v as i128,
|
SymbolValue::I64(v) => v as i128,
|
||||||
|
@ -778,20 +785,19 @@ impl Unifier {
|
||||||
false
|
false
|
||||||
};
|
};
|
||||||
|
|
||||||
if !can_convert {
|
if can_convert {
|
||||||
return self.incompatible_types(a, b)
|
self.set_a_to_b(a, b);
|
||||||
|
return Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.set_a_to_b(a, b);
|
return self.incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
|
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
|
||||||
for (v1, v2) in zip(val1, val2) {
|
if val2.iter().any(|val| !val1.contains(val)) {
|
||||||
if v1 != v2 {
|
|
||||||
return self.incompatible_types(a, b)
|
return self.incompatible_types(a, b)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,9 @@ class HybridGenericClass2(Generic[A, T]):
|
||||||
class HybridGenericClass3(Generic[T, A, B]):
|
class HybridGenericClass3(Generic[T, A, B]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def make_generic_1() -> ConstGenericClass[Literal[1]]:
|
||||||
|
return ...
|
||||||
|
|
||||||
def make_generic_2() -> ConstGenericClass[Literal[2]]:
|
def make_generic_2() -> ConstGenericClass[Literal[2]]:
|
||||||
return ...
|
return ...
|
||||||
|
|
||||||
|
@ -28,6 +31,9 @@ def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]:
|
||||||
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
|
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
|
||||||
return ...
|
return ...
|
||||||
|
|
||||||
|
def consume_generic_1_or_2(instance: ConstGenericClass[Literal[1, 2]]):
|
||||||
|
pass
|
||||||
|
|
||||||
def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
|
def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -42,6 +48,8 @@ def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0]
|
||||||
|
|
||||||
def f():
|
def f():
|
||||||
consume_generic_2(make_generic_2())
|
consume_generic_2(make_generic_2())
|
||||||
|
consume_generic_1_or_2(make_generic_1())
|
||||||
|
consume_generic_1_or_2(make_generic_2())
|
||||||
consume_generic2_1_2(make_generic2_1_2())
|
consume_generic2_1_2(make_generic2_1_2())
|
||||||
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
|
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
|
||||||
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
|
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
|
||||||
|
|
Loading…
Reference in New Issue