From 457d3b6cd73a0a6770e8be130a68be57fb793ae1 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Dec 2023 18:26:42 +0800 Subject: [PATCH] core: Refactor generic constants to `Literal` Better matches the syntax of `typing.Literal`. --- nac3core/src/codegen/concrete_type.rs | 15 ++-- nac3core/src/symbol_resolver.rs | 35 ++++++++ nac3core/src/toplevel/type_annotation.rs | 100 ++++++++++++----------- nac3core/src/typecheck/typedef/mod.rs | 85 +++++++++++++------ nac3standalone/demo/src/const_generic.py | 16 ++-- 5 files changed, 162 insertions(+), 89 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 472f78e1e..774516006 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -60,9 +60,8 @@ pub enum ConcreteTypeEnum { ret: ConcreteType, vars: HashMap, }, - TConstant { - value: SymbolValue, - ty: ConcreteType, + TLiteral { + values: Vec, }, } @@ -202,9 +201,8 @@ impl ConcreteTypeStore { TypeEnum::TFunc(signature) => { self.from_signature(unifier, primitives, signature, cache) } - TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant { - value: value.clone(), - ty: self.from_unifier_type(unifier, primitives, *ty, cache), + TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral { + values: values.clone(), }, _ => unreachable!("{:?}", ty_enum.get_type_name()), }; @@ -293,9 +291,8 @@ impl ConcreteTypeStore { .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .collect::>(), }), - ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant { - value: value.clone(), - ty: self.to_unifier_type(unifier, primitives, *ty, cache), + ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral { + values: values.clone(), loc: None, } }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2f08ecc27..0932bea8a 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -114,6 +114,41 @@ impl SymbolValue { } } + /// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value. + /// + /// * `constant` - The constant to create the value from. + pub fn from_constant_inferred( + constant: &Constant, + unifier: &mut Unifier + ) -> Result { + match constant { + Constant::None => Ok(SymbolValue::OptionNone), + Constant::Bool(b) => Ok(SymbolValue::Bool(*b)), + Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())), + Constant::Int(i) => { + let i = *i; + if i >= 0 { + i32::try_from(i).map(SymbolValue::I32) + .or_else(|_| i64::try_from(i).map(SymbolValue::I64)) + .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + } else { + u32::try_from(i).map(SymbolValue::U32) + .or_else(|_| u64::try_from(i).map(SymbolValue::U64)) + .map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) + } + } + Constant::Tuple(t) => { + let elems = t + .iter() + .map(|constant| Self::from_constant_inferred(constant, unifier)) + .collect::, _>>()?; + Ok(SymbolValue::Tuple(elems)) + } + Constant::Float(f) => Ok(SymbolValue::Double(*f)), + _ => Err(format!("Unsupported value type {constant:?}")), + } + } + /// Returns the [`Type`] representing the data type of this value. pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { match self { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4703bca22..d44597cde 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,5 +1,6 @@ use crate::symbol_resolver::SymbolValue; use super::*; +use nac3parser::ast::Constant; #[derive(Clone, Debug)] pub enum TypeAnnotation { @@ -13,16 +14,8 @@ pub enum TypeAnnotation { // can only be CustomClassKind Virtual(Box), TypeVar(Type), - /// A constant used in the context of a const-generic variable. - Constant { - /// The non-type variable associated with this constant. - /// - /// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the - /// const generic variable of which this constant is associated with. - ty: Type, - /// The constant value of this constant. - value: SymbolValue - }, + /// A `Literal` allowing a subset of literals. + Literal(Vec), List(Box), Tuple(Vec), } @@ -57,7 +50,7 @@ impl TypeAnnotation { } ) } - Constant { value, .. } => format!("Const({value})"), + Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)), Tuple(types) => { @@ -191,8 +184,7 @@ pub fn parse_ast_to_type_annotation_kinds( } let result = params_ast .iter() - .enumerate() - .map(|(idx, x)| { + .map(|x| { parse_ast_to_type_annotation_kinds( resolver, top_level_defs, @@ -203,7 +195,7 @@ pub fn parse_ast_to_type_annotation_kinds( locked.insert(obj_id, type_vars.clone()); locked.clone() }, - Some(type_vars[idx]), + None, ) }) .collect::, _>>()?; @@ -319,6 +311,46 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Tuple(type_annotations)) } + // Literal + ast::ExprKind::Subscript { value, slice, .. } + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into()) + } => { + let tup_elts = { + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.as_slice() + } else { + std::slice::from_ref(slice.as_ref()) + } + }; + let type_annotations = tup_elts + .iter() + .map(|e| { + match &e.node { + ast::ExprKind::Constant { value, .. } => Ok( + TypeAnnotation::Literal(vec![value.clone()]), + ), + _ => parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + e, + locked.clone(), + None, + ), + } + }) + .collect::, _>>()? + .into_iter() + .flat_map(|type_ann| match type_ann { + TypeAnnotation::Literal(values) => values, + _ => unreachable!(), + }) + .collect_vec(); + Ok(TypeAnnotation::Literal(type_annotations)) + } + // custom class ast::ExprKind::Subscript { value, slice, .. } => { if let ast::ExprKind::Name { id, .. } = &value.node { @@ -331,30 +363,7 @@ pub fn parse_ast_to_type_annotation_kinds( } ast::ExprKind::Constant { value, .. } => { - let type_var = type_var.expect("Expect type variable to be present"); - - let ntv_ty_enum = unifier.get_ty_immutable(type_var); - let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else { - unreachable!() - }; - let underlying_ty = underlying_ty[0]; - - let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier) - .map_err(|err| HashSet::from([err]))?; - - if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) { - return Err(HashSet::from([ - format!( - "expression {value} is not allowed for constant type annotation (at {})", - expr.location - ), - ])) - } - - Ok(TypeAnnotation::Constant { - ty: type_var, - value, - }) + Ok(TypeAnnotation::Literal(vec![value.clone()])) } _ => Err(HashSet::from([ @@ -495,14 +504,13 @@ pub fn get_type_from_type_annotation_kinds( Ok(ty) } TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), - TypeAnnotation::Constant { ty, value, .. } => { - let ty_enum = unifier.get_ty(*ty); - let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else { - unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()); - }; + TypeAnnotation::Literal(values) => { + let values = values.iter() + .map(|v| SymbolValue::from_constant_inferred(v, unifier)) + .collect::, _>>() + .map_err(|err| HashSet::from([err]))?; - let ty = ntv_underlying_ty[0]; - let var = unifier.get_fresh_constant(value.clone(), ty, *loc); + let var = unifier.get_fresh_literal(values, None); Ok(var) } TypeAnnotation::Virtual(ty) => { @@ -576,7 +584,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec {} + TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {} } result } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 03802e186..41a5e233d 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -140,12 +140,10 @@ pub enum TypeEnum { is_const_generic: bool, }, - /// A constant for substitution into a const generic variable. - TConstant { + /// A literal generic type matching `typing.Literal`. + TLiteral { /// The value of the constant. - value: SymbolValue, - /// The underlying type of the value. - ty: Type, + values: Vec, loc: Option, }, @@ -192,7 +190,7 @@ impl TypeEnum { match self { TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TVar { .. } => "TVar", - TypeEnum::TConstant { .. } => "TConstant", + TypeEnum::TLiteral { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", TypeEnum::TObj { .. } => "TObj", @@ -371,8 +369,7 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) } - /// Returns a fresh type representing a constant generic variable with the given underlying type - /// `ty`. + /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. pub fn get_fresh_const_generic_var( &mut self, ty: Type, @@ -384,17 +381,17 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) } - /// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given - /// `value` and type `ty`. - pub fn get_fresh_constant( + /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. + pub fn get_fresh_literal( &mut self, - value: SymbolValue, - ty: Type, + values: Vec, loc: Option, ) -> Type { - assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. })); - - self.add_ty(TypeEnum::TConstant { ty, value, loc }) + let ty_enum = TypeEnum::TLiteral { + values: values.clone(), + loc + }; + self.add_ty(ty_enum) } /// Unification would not unify rigid variables with other types, but we want to do this for @@ -469,7 +466,7 @@ impl Unifier { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { - TRigidVar { .. } | TConstant { .. } => true, + TRigidVar { .. } | TLiteral { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), @@ -747,18 +744,54 @@ impl Unifier { self.set_a_to_b(a, b); } - (TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => { - let ty1 = ty1[0]; + (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => { + assert_eq!(tys.len(), 1); + assert_eq!(values.len(), 1); + + let primitives = &self.primitive_store + .expect("Expected PrimitiveStore to be present"); + + let ty = tys[0]; + let value= &values[0]; + let value_ty = value.get_type(primitives, self); + + // If the types don't match, try to implicitly promote integers + if !self.unioned(ty, value_ty) { + + let num_val = match *value { + SymbolValue::I32(v) => v as i128, + SymbolValue::I64(v) => v as i128, + SymbolValue::U32(v) => v as i128, + SymbolValue::U64(v) => v as i128, + _ => return self.incompatible_types(a, b), + }; + + let can_convert = if self.unioned(ty, primitives.int32) { + i32::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.int64) { + i64::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.uint32) { + u32::try_from(num_val).is_ok() + } else if self.unioned(ty, primitives.uint64) { + u64::try_from(num_val).is_ok() + } else { + false + }; + + if !can_convert { + return self.incompatible_types(a, b) + } + } - self.unify_impl(ty1, *ty2, false)?; self.set_a_to_b(a, b); } - (TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => { - if val1 != val2 { - return self.incompatible_types(a, b) + (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { + for (v1, v2) in zip(val1, val2) { + if v1 != v2 { + return self.incompatible_types(a, b) + } } - self.unify_impl(*ty1, *ty2, false)?; self.set_a_to_b(a, b); } @@ -1016,8 +1049,8 @@ impl Unifier { }; n } - TypeEnum::TConstant { value, .. } => { - format!("const({value})") + TypeEnum::TLiteral { values, .. } => { + format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) } TypeEnum::TTuple { ty } => { let mut fields = diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py index 2ddca9978..1d44b728e 100644 --- a/nac3standalone/demo/src/const_generic.py +++ b/nac3standalone/demo/src/const_generic.py @@ -16,28 +16,28 @@ class HybridGenericClass2(Generic[A, T]): class HybridGenericClass3(Generic[T, A, B]): pass -def make_generic_2() -> ConstGenericClass[2]: +def make_generic_2() -> ConstGenericClass[Literal[2]]: return ... -def make_generic2_1_2() -> ConstGeneric2Class[1, 2]: +def make_generic2_1_2() -> ConstGeneric2Class[Literal[1], Literal[2]]: return ... -def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]: +def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]: return ... -def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]: +def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]: return ... -def consume_generic_2(instance: ConstGenericClass[2]): +def consume_generic_2(instance: ConstGenericClass[Literal[2]]): pass -def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]): +def consume_generic2_1_2(instance: ConstGeneric2Class[Literal[1], Literal[2]]): pass -def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]): +def consume_hybrid_class_2_i32(instance: HybridGenericClass2[Literal[2], int32]): pass -def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]): +def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0], Literal[1]]): pass def f():