core: Do not keep unification result for function arguments

For some reason, when unifying a function call parameter with an
argument, subsequent calls to the same function will only accept the
type of the substituted argument.

This affect snippets like:

```
def make1() -> C[Literal[1]]:
    return ...

def make2() -> C[Literal[2]]:
    return ...

def consume(instance: C[Literal[1, 2]]):
    pass

consume(make1())
consume(make2())
```

The last statement will result in a compiler error, as the parameter of
consume is replaced with C[Literal[1]].

We fix this by getting a snapshot before performing unification, and
restoring the snapshot after unification succeeds.
This commit is contained in:
David Mak 2023-12-13 21:55:30 +08:00 committed by sb10q
parent 0bbc9ce6f5
commit f09f3c27a5
2 changed files with 26 additions and 12 deletions

View File

@ -529,10 +529,12 @@ impl Unifier {
} }
required.pop(); required.pop();
let (name, expected) = all_names.pop().unwrap(); let (name, expected) = all_names.pop().unwrap();
let snapshot = self.unification_table.get_snapshot();
self.unify_impl(expected, *t, false).map_err(|_| { self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot(); self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?; })?;
self.unification_table.restore_snapshot(snapshot);
} }
for (k, t) in kwargs { for (k, t) in kwargs {
if let Some(i) = required.iter().position(|v| v == k) { if let Some(i) = required.iter().position(|v| v == k) {
@ -546,10 +548,12 @@ impl Unifier {
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?; })?;
let (name, expected) = all_names.remove(i); let (name, expected) = all_names.remove(i);
let snapshot = self.unification_table.get_snapshot();
self.unify_impl(expected, *t, false).map_err(|_| { self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot(); self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?; })?;
self.unification_table.restore_snapshot(snapshot);
} }
if !required.is_empty() { if !required.is_empty() {
self.restore_snapshot(); self.restore_snapshot();
@ -746,18 +750,21 @@ impl Unifier {
(TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => {
assert_eq!(tys.len(), 1); assert_eq!(tys.len(), 1);
assert_eq!(values.len(), 1);
let primitives = &self.primitive_store let primitives = &self.primitive_store
.expect("Expected PrimitiveStore to be present"); .expect("Expected PrimitiveStore to be present");
let ty = tys[0]; let ty = tys[0];
let value= &values[0];
for value in values {
let value_ty = value.get_type(primitives, self); let value_ty = value.get_type(primitives, self);
// If the types don't match, try to implicitly promote integers if self.unioned(ty, value_ty) {
if !self.unioned(ty, value_ty) { self.set_a_to_b(a, b);
return Ok(())
}
// The types don't match, try to implicitly promote integers
let num_val = match *value { let num_val = match *value {
SymbolValue::I32(v) => v as i128, SymbolValue::I32(v) => v as i128,
SymbolValue::I64(v) => v as i128, SymbolValue::I64(v) => v as i128,
@ -778,20 +785,19 @@ impl Unifier {
false false
}; };
if !can_convert { if can_convert {
return self.incompatible_types(a, b) self.set_a_to_b(a, b);
return Ok(())
} }
} }
self.set_a_to_b(a, b); return self.incompatible_types(a, b)
} }
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
for (v1, v2) in zip(val1, val2) { if val2.iter().any(|val| !val1.contains(val)) {
if v1 != v2 {
return self.incompatible_types(a, b) return self.incompatible_types(a, b)
} }
}
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }

View File

@ -16,6 +16,9 @@ class HybridGenericClass2(Generic[A, T]):
class HybridGenericClass3(Generic[T, A, B]): class HybridGenericClass3(Generic[T, A, B]):
pass pass
def make_generic_1() -> ConstGenericClass[Literal[1]]:
return ...
def make_generic_2() -> ConstGenericClass[Literal[2]]: def make_generic_2() -> ConstGenericClass[Literal[2]]:
return ... return ...
@ -28,6 +31,9 @@ def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]:
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
return ... return ...
def consume_generic_1_or_2(instance: ConstGenericClass[Literal[1, 2]]):
pass
def consume_generic_2(instance: ConstGenericClass[Literal[2]]): def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
pass pass
@ -42,6 +48,8 @@ def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0]
def f(): def f():
consume_generic_2(make_generic_2()) consume_generic_2(make_generic_2())
consume_generic_1_or_2(make_generic_1())
consume_generic_1_or_2(make_generic_2())
consume_generic2_1_2(make_generic2_1_2()) consume_generic2_1_2(make_generic2_1_2())
consume_hybrid_class_2_i32(make_hybrid_class_2_int32()) consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1()) consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())