From a6331ef88b490fc97b0accc5eac1cde1591a387e Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 13:26:14 +0800 Subject: [PATCH 1/7] standalone: Output id of undefined identifier --- nac3standalone/src/basic_symbol_resolver.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index cdc3575..e483c6d 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -64,7 +64,9 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: StrRef) -> Result> { self.0.id_to_def.lock().get(&id).copied() - .ok_or_else(|| HashSet::from(["Undefined identifier".to_string()])) + .ok_or_else(|| HashSet::from([ + format!("Undefined identifier `{id}`"), + ])) } fn get_string_id(&self, s: &str) -> i32 { -- 2.44.1 From 8d47d6d4dd6511b431eb4d74f577fc96f21d3350 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 18:15:22 +0800 Subject: [PATCH 2/7] core: Fix indentation --- nac3core/src/toplevel/type_annotation.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index a44de76..4703bca 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -390,17 +390,17 @@ pub fn get_type_from_type_annotation_kinds( ])) } - let param_ty = params - .iter() - .map(|x| { - get_type_from_type_annotation_kinds( - top_level_defs, - unifier, - x, - subst_list - ) - }) - .collect::, _>>()?; + let param_ty = params + .iter() + .map(|x| { + get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + x, + subst_list + ) + }) + .collect::, _>>()?; let subst = { // check for compatible range -- 2.44.1 From 875d08e9bbf386c41dc2b72405840299c22818f1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 18:23:32 +0800 Subject: [PATCH 3/7] core: Add PrimitiveStore into Unifier This will be used during unification between a const generic variable and a `Literal`. --- nac3core/src/codegen/mod.rs | 1 + nac3core/src/toplevel/helper.rs | 1 + nac3core/src/typecheck/type_inferencer/test.rs | 3 +++ nac3core/src/typecheck/typedef/mod.rs | 15 ++++++++++++++- 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 0917d8c..51496e6 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -580,6 +580,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; (Unifier::from_shared_unifier(unifier), *primitives) }; + unifier.put_primitive_store(&primitives); unifier.top_level = Some(top_level_ctx.clone()); let mut cache = HashMap::new(); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index e11de99..0f4d375 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -145,6 +145,7 @@ impl TopLevelComposer { exception, option, }; + unifier.put_primitive_store(&primitives); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 4589c83..f3ac2d0 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -148,6 +148,7 @@ impl TestEnvironment { uint64, option, }; + unifier.put_primitive_store(&primitives); set_primitives_magic_methods(&primitives, &mut unifier); let id_to_name = [ @@ -296,6 +297,8 @@ impl TestEnvironment { option, }; + unifier.put_primitive_store(&primitives); + let (v0, id) = unifier.get_dummy_var(); let foo_ty = unifier.add_ty(TypeEnum::TObj { diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 85dfedc..03802e1 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -13,6 +13,7 @@ use super::type_error::{TypeError, TypeErrorKind}; use super::unification_table::{UnificationKey, UnificationTable}; use crate::symbol_resolver::SymbolValue; use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; +use crate::typecheck::type_inferencer::PrimitiveStore; #[cfg(test)] mod test; @@ -211,7 +212,8 @@ pub struct Unifier { pub(crate) calls: Vec>, var_id: u32, unify_cache: HashSet<(Type, Type)>, - snapshot: Option<(usize, u32)> + snapshot: Option<(usize, u32)>, + primitive_store: Option, } impl Default for Unifier { @@ -231,9 +233,19 @@ impl Unifier { unify_cache: HashSet::new(), top_level: None, snapshot: None, + primitive_store: None, } } + /// Sets the [PrimitiveStore] instance within this `Unifier`. + /// + /// This function can only be invoked once. Any subsequent invocations will result in an + /// assertion error.. + pub fn put_primitive_store(&mut self, primitives: &PrimitiveStore) { + assert!(self.primitive_store.is_none()); + self.primitive_store.replace(primitives.clone()); + } + pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable> { &mut self.unification_table } @@ -252,6 +264,7 @@ impl Unifier { top_level: None, unify_cache: HashSet::new(), snapshot: None, + primitive_store: None, } } -- 2.44.1 From f742d0c32dae59195032703fda40b4d25ff491da Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 18:26:42 +0800 Subject: [PATCH 4/7] core: Refactor generic constants to `Literal` Better matches the syntax of `typing.Literal`. --- nac3core/src/codegen/concrete_type.rs | 15 ++-- nac3core/src/symbol_resolver.rs | 35 ++++++++ nac3core/src/toplevel/type_annotation.rs | 100 ++++++++++++----------- nac3core/src/typecheck/typedef/mod.rs | 85 +++++++++++++------ nac3standalone/demo/src/const_generic.py | 16 ++-- 5 files changed, 162 insertions(+), 89 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 472f78e..7745160 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -60,9 +60,8 @@ pub enum ConcreteTypeEnum { ret: ConcreteType, vars: HashMap, }, - TConstant { - value: SymbolValue, - ty: ConcreteType, + TLiteral { + values: Vec, }, } @@ -202,9 +201,8 @@ impl ConcreteTypeStore { TypeEnum::TFunc(signature) => { self.from_signature(unifier, primitives, signature, cache) } - TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant { - value: value.clone(), - ty: self.from_unifier_type(unifier, primitives, *ty, cache), + TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral { + values: values.clone(), }, _ => unreachable!("{:?}", ty_enum.get_type_name()), }; @@ -293,9 +291,8 @@ impl ConcreteTypeStore { .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .collect::>(), }), - ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant { - value: value.clone(), - ty: self.to_unifier_type(unifier, primitives, *ty, cache), + ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral { + values: values.clone(), loc: None, } }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2f08ecc..0932bea 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -114,6 +114,41 @@ impl SymbolValue { } } + /// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value. + /// + /// * `constant` - The constant to create the value from. + pub fn from_constant_inferred( + constant: &Constant, + unifier: &mut Unifier + ) -> Result { + match constant { + Constant::None => Ok(SymbolValue::OptionNone), + Constant::Bool(b) => Ok(SymbolValue::Bool(*b)), + Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())), + Constant::Int(i) => { + let i = *i; + if i >= 0 { + i32::try_from(i).map(SymbolValue::I32) + .or_else(|_| i64::try_from(i).map(SymbolValue::I64)) + .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + } else { + u32::try_from(i).map(SymbolValue::U32) + .or_else(|_| u64::try_from(i).map(SymbolValue::U64)) + .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + } + } + Constant::Tuple(t) => { + let elems = t + .iter() + .map(|constant| Self::from_constant_inferred(constant, unifier)) + .collect::, _>>()?; + Ok(SymbolValue::Tuple(elems)) + } + Constant::Float(f) => Ok(SymbolValue::Double(*f)), + _ => Err(format!("Unsupported value type {constant:?}")), + } + } + /// Returns the [`Type`] representing the data type of this value. pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { match self { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4703bca..d44597c 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,5 +1,6 @@ use crate::symbol_resolver::SymbolValue; use super::*; +use nac3parser::ast::Constant; #[derive(Clone, Debug)] pub enum TypeAnnotation { @@ -13,16 +14,8 @@ pub enum TypeAnnotation { // can only be CustomClassKind Virtual(Box), TypeVar(Type), - /// A constant used in the context of a const-generic variable. - Constant { - /// The non-type variable associated with this constant. - /// - /// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the - /// const generic variable of which this constant is associated with. - ty: Type, - /// The constant value of this constant. - value: SymbolValue - }, + /// A `Literal` allowing a subset of literals. + Literal(Vec), List(Box), Tuple(Vec), } @@ -57,7 +50,7 @@ impl TypeAnnotation { } ) } - Constant { value, .. } => format!("Const({value})"), + Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)), Tuple(types) => { @@ -191,8 +184,7 @@ pub fn parse_ast_to_type_annotation_kinds( } let result = params_ast .iter() - .enumerate() - .map(|(idx, x)| { + .map(|x| { parse_ast_to_type_annotation_kinds( resolver, top_level_defs, @@ -203,7 +195,7 @@ pub fn parse_ast_to_type_annotation_kinds( locked.insert(obj_id, type_vars.clone()); locked.clone() }, - Some(type_vars[idx]), + None, ) }) .collect::, _>>()?; @@ -319,6 +311,46 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Tuple(type_annotations)) } + // Literal + ast::ExprKind::Subscript { value, slice, .. } + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into()) + } => { + let tup_elts = { + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.as_slice() + } else { + std::slice::from_ref(slice.as_ref()) + } + }; + let type_annotations = tup_elts + .iter() + .map(|e| { + match &e.node { + ast::ExprKind::Constant { value, .. } => Ok( + TypeAnnotation::Literal(vec![value.clone()]), + ), + _ => parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + e, + locked.clone(), + None, + ), + } + }) + .collect::, _>>()? + .into_iter() + .flat_map(|type_ann| match type_ann { + TypeAnnotation::Literal(values) => values, + _ => unreachable!(), + }) + .collect_vec(); + Ok(TypeAnnotation::Literal(type_annotations)) + } + // custom class ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { @@ -331,30 +363,7 @@ pub fn parse_ast_to_type_annotation_kinds( } ast::ExprKind::Constant { value, .. } => { - let type_var = type_var.expect("Expect type variable to be present"); - - let ntv_ty_enum = unifier.get_ty_immutable(type_var); - let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else { - unreachable!() - }; - let underlying_ty = underlying_ty[0]; - - let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier) - .map_err(|err| HashSet::from([err]))?; - - if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) { - return Err(HashSet::from([ - format!( - "expression {value} is not allowed for constant type annotation (at {})", - expr.location - ), - ])) - } - - Ok(TypeAnnotation::Constant { - ty: type_var, - value, - }) + Ok(TypeAnnotation::Literal(vec![value.clone()])) } _ => Err(HashSet::from([ @@ -495,14 +504,13 @@ pub fn get_type_from_type_annotation_kinds( Ok(ty) } TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), - TypeAnnotation::Constant { ty, value, .. } => { - let ty_enum = unifier.get_ty(*ty); - let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else { - unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()); - }; + TypeAnnotation::Literal(values) => { + let values = values.iter() + .map(|v| SymbolValue::from_constant_inferred(v, unifier)) + .collect::, _>>() + .map_err(|err| HashSet::from([err]))?; - let ty = ntv_underlying_ty[0]; - let var = unifier.get_fresh_constant(value.clone(), ty, *loc); + let var = unifier.get_fresh_literal(values, None); Ok(var) } TypeAnnotation::Virtual(ty) => { @@ -576,7 +584,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec {} + TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {} } result } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 03802e1..41a5e23 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -140,12 +140,10 @@ pub enum TypeEnum { is_const_generic: bool, }, - /// A constant for substitution into a const generic variable. - TConstant { + /// A literal generic type matching `typing.Literal`. + TLiteral { /// The value of the constant. - value: SymbolValue, - /// The underlying type of the value. - ty: Type, + values: Vec, loc: Option, }, @@ -192,7 +190,7 @@ impl TypeEnum { match self { TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TVar { .. } => "TVar", - TypeEnum::TConstant { .. } => "TConstant", + TypeEnum::TLiteral { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", TypeEnum::TObj { .. } => "TObj", @@ -371,8 +369,7 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) } - /// Returns a fresh type representing a constant generic variable with the given underlying type - /// `ty`. + /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. pub fn get_fresh_const_generic_var( &mut self, ty: Type, @@ -384,17 +381,17 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) } - /// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given - /// `value` and type `ty`. - pub fn get_fresh_constant( + /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. + pub fn get_fresh_literal( &mut self, - value: SymbolValue, - ty: Type, + values: Vec, loc: Option, ) -> Type { - assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. })); - - self.add_ty(TypeEnum::TConstant { ty, value, loc }) + let ty_enum = TypeEnum::TLiteral { + values: values.clone(), + loc + }; + self.add_ty(ty_enum) } /// Unification would not unify rigid variables with other types, but we want to do this for @@ -469,7 +466,7 @@ impl Unifier { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { - TRigidVar { .. } | TConstant { .. } => true, + TRigidVar { .. } | TLiteral { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), @@ -747,18 +744,54 @@ impl Unifier { self.set_a_to_b(a, b); } - (TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => { - let ty1 = ty1[0]; + (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { + assert_eq!(tys.len(), 1); + assert_eq!(values.len(), 1); + + let primitives = &self.primitive_store + .expect("Expected PrimitiveStore to be present"); + + let ty = tys[0]; + let value= &values[0]; + let value_ty = value.get_type(primitives, self); + + // If the types don't match, try to implicitly promote integers + if !self.unioned(ty, value_ty) { + + let num_val = match *value { + SymbolValue::I32(v) => v as i128, + SymbolValue::I64(v) => v as i128, + SymbolValue::U32(v) => v as i128, + SymbolValue::U64(v) => v as i128, + _ => return self.incompatible_types(a, b), + }; + + let can_convert = if self.unioned(ty, primitives.int32) { + i32::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.int64) { + i64::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.uint32) { + u32::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.uint64) { + u64::try_from(num_val).is_ok() + } else { + false + }; + + if !can_convert { + return self.incompatible_types(a, b) + } + } - self.unify_impl(ty1, *ty2, false)?; self.set_a_to_b(a, b); } - (TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => { - if val1 != val2 { - return self.incompatible_types(a, b) + (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { + for (v1, v2) in zip(val1, val2) { + if v1 != v2 { + return self.incompatible_types(a, b) + } } - self.unify_impl(*ty1, *ty2, false)?; self.set_a_to_b(a, b); } @@ -1016,8 +1049,8 @@ impl Unifier { }; n } - TypeEnum::TConstant { value, .. } => { - format!("const({value})") + TypeEnum::TLiteral { values, .. } => { + format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) } TypeEnum::TTuple { ty } => { let mut fields = diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py index 2ddca99..1d44b72 100644 --- a/nac3standalone/demo/src/const_generic.py +++ b/nac3standalone/demo/src/const_generic.py @@ -16,28 +16,28 @@ class HybridGenericClass2(Generic[A, T]): class HybridGenericClass3(Generic[T, A, B]): pass -def make_generic_2() -> ConstGenericClass[2]: +def make_generic_2() -> ConstGenericClass[Literal[2]]: return ... -def make_generic2_1_2() -> ConstGeneric2Class[1, 2]: +def make_generic2_1_2() -> ConstGeneric2Class[Literal[1], Literal[2]]: return ... -def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]: +def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]: return ... -def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]: +def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: return ... -def consume_generic_2(instance: ConstGenericClass[2]): +def consume_generic_2(instance: ConstGenericClass[Literal[2]]): pass -def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]): +def consume_generic2_1_2(instance: ConstGeneric2Class[Literal[1], Literal[2]]): pass -def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]): +def consume_hybrid_class_2_i32(instance: HybridGenericClass2[Literal[2], int32]): pass -def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]): +def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0], Literal[1]]): pass def f(): -- 2.44.1 From bf830f216e08cf5cebe07388b87395491770e1e1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 15 Dec 2023 12:26:43 +0800 Subject: [PATCH 5/7] core: Deduplicate values in `Literal` Matches the behavior with `typing.Literal`. --- nac3core/src/typecheck/typedef/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 41a5e23..d275d8a 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -388,7 +388,7 @@ impl Unifier { loc: Option, ) -> Type { let ty_enum = TypeEnum::TLiteral { - values: values.clone(), + values: values.into_iter().dedup().collect(), loc }; self.add_ty(ty_enum) -- 2.44.1 From 2722a7a17f8d0db45bdeb7561bd33e4628af8d40 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 21:55:30 +0800 Subject: [PATCH 6/7] core: Do not keep unification result for function arguments For some reason, when unifying a function call parameter with an argument, subsequent calls to the same function will only accept the type of the substituted argument. This affect snippets like: ``` def make1() -> C[Literal[1]]: return ... def make2() -> C[Literal[2]]: return ... def consume(instance: C[Literal[1, 2]]): pass consume(make1()) consume(make2()) ``` The last statement will result in a compiler error, as the parameter of consume is replaced with C[Literal[1]]. We fix this by getting a snapshot before performing unification, and restoring the snapshot after unification succeeds. --- nac3core/src/typecheck/typedef/mod.rs | 30 ++++++++++++++---------- nac3standalone/demo/src/const_generic.py | 8 +++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index d275d8a..4a45c6f 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -529,10 +529,12 @@ impl Unifier { } required.pop(); let (name, expected) = all_names.pop().unwrap(); + let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; + self.unification_table.restore_snapshot(snapshot); } for (k, t) in kwargs { if let Some(i) = required.iter().position(|v| v == k) { @@ -546,10 +548,12 @@ impl Unifier { TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) })?; let (name, expected) = all_names.remove(i); + let snapshot = self.unification_table.get_snapshot(); self.unify_impl(expected, *t, false).map_err(|_| { self.restore_snapshot(); TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) })?; + self.unification_table.restore_snapshot(snapshot); } if !required.is_empty() { self.restore_snapshot(); @@ -746,18 +750,21 @@ impl Unifier { (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { assert_eq!(tys.len(), 1); - assert_eq!(values.len(), 1); let primitives = &self.primitive_store .expect("Expected PrimitiveStore to be present"); let ty = tys[0]; - let value= &values[0]; - let value_ty = value.get_type(primitives, self); - // If the types don't match, try to implicitly promote integers - if !self.unioned(ty, value_ty) { + for value in values { + let value_ty = value.get_type(primitives, self); + if self.unioned(ty, value_ty) { + self.set_a_to_b(a, b); + return Ok(()) + } + + // The types don't match, try to implicitly promote integers let num_val = match *value { SymbolValue::I32(v) => v as i128, SymbolValue::I64(v) => v as i128, @@ -778,19 +785,18 @@ impl Unifier { false }; - if !can_convert { - return self.incompatible_types(a, b) + if can_convert { + self.set_a_to_b(a, b); + return Ok(()) } } - self.set_a_to_b(a, b); + return self.incompatible_types(a, b) } (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { - for (v1, v2) in zip(val1, val2) { - if v1 != v2 { - return self.incompatible_types(a, b) - } + if val2.iter().any(|val| !val1.contains(val)) { + return self.incompatible_types(a, b) } self.set_a_to_b(a, b); diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py index 1d44b72..1d917ff 100644 --- a/nac3standalone/demo/src/const_generic.py +++ b/nac3standalone/demo/src/const_generic.py @@ -16,6 +16,9 @@ class HybridGenericClass2(Generic[A, T]): class HybridGenericClass3(Generic[T, A, B]): pass +def make_generic_1() -> ConstGenericClass[Literal[1]]: + return ... + def make_generic_2() -> ConstGenericClass[Literal[2]]: return ... @@ -28,6 +31,9 @@ def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: return ... +def consume_generic_1_or_2(instance: ConstGenericClass[Literal[1, 2]]): + pass + def consume_generic_2(instance: ConstGenericClass[Literal[2]]): pass @@ -42,6 +48,8 @@ def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0] def f(): consume_generic_2(make_generic_2()) + consume_generic_1_or_2(make_generic_1()) + consume_generic_1_or_2(make_generic_2()) consume_generic2_1_2(make_generic2_1_2()) consume_hybrid_class_2_i32(make_hybrid_class_2_int32()) consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1()) -- 2.44.1 From a4ee2ae22ef5e1169538ab335c278cc63214d133 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 18:27:06 +0800 Subject: [PATCH 7/7] core: Remove redundant argument in type annotation parsing --- nac3core/src/toplevel/composer.rs | 6 ------ nac3core/src/toplevel/type_annotation.rs | 7 ------- nac3standalone/src/main.rs | 2 -- 3 files changed, 15 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 68cfb72..8e612a6 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -563,7 +563,6 @@ impl TopLevelComposer { &primitive_types, b, vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), - None, )?; if let TypeAnnotation::CustomClass { .. } = &base_ty { @@ -904,7 +903,6 @@ impl TopLevelComposer { // NOTE: since only class need this, for function // it should be fine to be empty map HashMap::new(), - None, )?; let type_vars_within = @@ -971,7 +969,6 @@ impl TopLevelComposer { // NOTE: since only class need this, for function // it should be fine to be empty map HashMap::new(), - None, )? }; @@ -1158,7 +1155,6 @@ impl TopLevelComposer { vec![(class_id, class_type_vars_def.clone())] .into_iter() .collect(), - None, )? }; // find type vars within this method parameter type annotation @@ -1224,7 +1220,6 @@ impl TopLevelComposer { primitives, result, vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), - None, )?; // find type vars within this return type annotation let type_vars_within = @@ -1319,7 +1314,6 @@ impl TopLevelComposer { primitives, annotation.as_ref(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), - None, )?; // find type vars within this return type annotation let type_vars_within = diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index d44597c..3cd2f25 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -74,7 +74,6 @@ pub fn parse_ast_to_type_annotation_kinds( expr: &ast::Expr, // the key stores the type_var of this topleveldef::class, we only need this field here locked: HashMap>, - type_var: Option, ) -> Result> { let name_handle = |id: &StrRef, unifier: &mut Unifier, @@ -195,7 +194,6 @@ pub fn parse_ast_to_type_annotation_kinds( locked.insert(obj_id, type_vars.clone()); locked.clone() }, - None, ) }) .collect::, _>>()?; @@ -231,7 +229,6 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, - None, )?; if !matches!(def, TypeAnnotation::CustomClass { .. }) { unreachable!("must be concretized custom class kind in the virtual") @@ -252,7 +249,6 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, - None, )?; Ok(TypeAnnotation::List(def_ann.into())) } @@ -270,7 +266,6 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, - None, )?; let id = if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { @@ -304,7 +299,6 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, e, locked.clone(), - None, ) }) .collect::, _>>()?; @@ -337,7 +331,6 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, e, locked.clone(), - None, ), } }) diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index f34f55a..eb746a8 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -104,7 +104,6 @@ fn handle_typevar_definition( primitives, x, HashMap::default(), - None, )?; get_type_from_type_annotation_kinds( def_list, unifier, &ty, &mut None @@ -146,7 +145,6 @@ fn handle_typevar_definition( primitives, &args[1], HashMap::default(), - None, )?; let constraint = get_type_from_type_annotation_kinds( def_list, unifier, &ty, &mut None -- 2.44.1