diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 4a45c6f..d275d8a 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -529,12 +529,10 @@ impl Unifier { } required.pop(); let (name, expected) = all_names.pop().unwrap(); - let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; - self.unification_table.restore_snapshot(snapshot); } for (k, t) in kwargs { if let Some(i) = required.iter().position(|v| v == k) { @@ -548,12 +546,10 @@ impl Unifier { TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) })?; let (name, expected) = all_names.remove(i); - let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; - self.unification_table.restore_snapshot(snapshot); } if !required.is_empty() { self.restore_snapshot(); @@ -750,21 +746,18 @@ impl Unifier { (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { assert_eq!(tys.len(), 1); + assert_eq!(values.len(), 1); let primitives = &self.primitive_store .expect("Expected PrimitiveStore to be present"); let ty = tys[0]; + let value= &values[0]; + let value_ty = value.get_type(primitives, self); - for value in values { - let value_ty = value.get_type(primitives, self); + // If the types don't match, try to implicitly promote integers + if !self.unioned(ty, value_ty) { - if self.unioned(ty, value_ty) { - self.set_a_to_b(a, b); - return Ok(()) - } - - // The types don't match, try to implicitly promote integers let num_val = match *value { SymbolValue::I32(v) => v as i128, SymbolValue::I64(v) => v as i128, @@ -785,18 +778,19 @@ impl Unifier { false }; - if can_convert { - self.set_a_to_b(a, b); - return Ok(()) + if !can_convert { + return self.incompatible_types(a, b) } } - return self.incompatible_types(a, b) + self.set_a_to_b(a, b); } (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { - if val2.iter().any(|val| !val1.contains(val)) { - return self.incompatible_types(a, b) + for (v1, v2) in zip(val1, val2) { + if v1 != v2 { + return self.incompatible_types(a, b) + } } self.set_a_to_b(a, b); diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py index 1d917ff..1d44b72 100644 --- a/nac3standalone/demo/src/const_generic.py +++ b/nac3standalone/demo/src/const_generic.py @@ -16,9 +16,6 @@ class HybridGenericClass2(Generic[A, T]): class HybridGenericClass3(Generic[T, A, B]): pass -def make_generic_1() -> ConstGenericClass[Literal[1]]: - return ... - def make_generic_2() -> ConstGenericClass[Literal[2]]: return ... @@ -31,9 +28,6 @@ def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: return ... -def consume_generic_1_or_2(instance: ConstGenericClass[Literal[1, 2]]): - pass - def consume_generic_2(instance: ConstGenericClass[Literal[2]]): pass @@ -48,8 +42,6 @@ def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0] def f(): consume_generic_2(make_generic_2()) - consume_generic_1_or_2(make_generic_1()) - consume_generic_1_or_2(make_generic_2()) consume_generic2_1_2(make_generic2_1_2()) consume_hybrid_class_2_i32(make_hybrid_class_2_int32()) consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())