diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index d275d8a76..4a45c6f38 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -529,10 +529,12 @@ impl Unifier { } required.pop(); let (name, expected) = all_names.pop().unwrap(); + let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; + self.unification_table.restore_snapshot(snapshot); } for (k, t) in kwargs { if let Some(i) = required.iter().position(|v| v == k) { @@ -546,10 +548,12 @@ impl Unifier { TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) })?; let (name, expected) = all_names.remove(i); + let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; + self.unification_table.restore_snapshot(snapshot); } if !required.is_empty() { self.restore_snapshot(); @@ -746,18 +750,21 @@ impl Unifier { (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { assert_eq!(tys.len(), 1); - assert_eq!(values.len(), 1); let primitives = &self.primitive_store .expect("Expected PrimitiveStore to be present"); let ty = tys[0]; - let value= &values[0]; - let value_ty = value.get_type(primitives, self); - // If the types don't match, try to implicitly promote integers - if !self.unioned(ty, value_ty) { + for value in values { + let value_ty = value.get_type(primitives, self); + if self.unioned(ty, value_ty) { + self.set_a_to_b(a, b); + return Ok(()) + } + + // The types don't match, try to implicitly promote integers let num_val = match *value { SymbolValue::I32(v) => v as i128, SymbolValue::I64(v) => v as i128, @@ -778,19 +785,18 @@ impl Unifier { false }; - if !can_convert { - return self.incompatible_types(a, b) + if can_convert { + self.set_a_to_b(a, b); + return Ok(()) } } - self.set_a_to_b(a, b); + return self.incompatible_types(a, b) } (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { - for (v1, v2) in zip(val1, val2) { - if v1 != v2 { - return self.incompatible_types(a, b) - } + if val2.iter().any(|val| !val1.contains(val)) { + return self.incompatible_types(a, b) } self.set_a_to_b(a, b); diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py index 1d44b728e..1d917ff5a 100644 --- a/nac3standalone/demo/src/const_generic.py +++ b/nac3standalone/demo/src/const_generic.py @@ -16,6 +16,9 @@ class HybridGenericClass2(Generic[A, T]): class HybridGenericClass3(Generic[T, A, B]): pass +def make_generic_1() -> ConstGenericClass[Literal[1]]: + return ... + def make_generic_2() -> ConstGenericClass[Literal[2]]: return ... @@ -28,6 +31,9 @@ def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: return ... +def consume_generic_1_or_2(instance: ConstGenericClass[Literal[1, 2]]): + pass + def consume_generic_2(instance: ConstGenericClass[Literal[2]]): pass @@ -42,6 +48,8 @@ def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0] def f(): consume_generic_2(make_generic_2()) + consume_generic_1_or_2(make_generic_1()) + consume_generic_1_or_2(make_generic_2()) consume_generic2_1_2(make_generic2_1_2()) consume_hybrid_class_2_i32(make_hybrid_class_2_int32()) consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())