core: Refactor generic constants to Literal

Better matches the syntax of `typing.Literal`.
This commit is contained in:
David Mak 2023-12-13 18:26:42 +08:00 committed by sb10q
parent 5f692debd8
commit 457d3b6cd7
5 changed files with 162 additions and 89 deletions
nac3core/src
nac3standalone/demo/src

View File

@ -60,9 +60,8 @@ pub enum ConcreteTypeEnum {
ret: ConcreteType, ret: ConcreteType,
vars: HashMap<u32, ConcreteType>, vars: HashMap<u32, ConcreteType>,
}, },
TConstant { TLiteral {
value: SymbolValue, values: Vec<SymbolValue>,
ty: ConcreteType,
}, },
} }
@ -202,9 +201,8 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache) self.from_signature(unifier, primitives, signature, cache)
} }
TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant { TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral {
value: value.clone(), values: values.clone(),
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
}, },
_ => unreachable!("{:?}", ty_enum.get_type_name()), _ => unreachable!("{:?}", ty_enum.get_type_name()),
}; };
@ -293,9 +291,8 @@ impl ConcreteTypeStore {
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
}), }),
ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant { ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral {
value: value.clone(), values: values.clone(),
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
loc: None, loc: None,
} }
}; };

View File

@ -114,6 +114,41 @@ impl SymbolValue {
} }
} }
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(
constant: &Constant,
unifier: &mut Unifier
) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i).map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
} else {
u32::try_from(i).map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(|constant| Self::from_constant_inferred(constant, unifier))
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Returns the [`Type`] representing the data type of this value. /// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self { match self {

View File

@ -1,5 +1,6 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
use nac3parser::ast::Constant;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -13,16 +14,8 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
/// A constant used in the context of a const-generic variable. /// A `Literal` allowing a subset of literals.
Constant { Literal(Vec<Constant>),
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>), List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -57,7 +50,7 @@ impl TypeAnnotation {
} }
) )
} }
Constant { value, .. } => format!("Const({value})"), Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
@ -191,8 +184,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
let result = params_ast let result = params_ast
.iter() .iter()
.enumerate() .map(|x| {
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds( parse_ast_to_type_annotation_kinds(
resolver, resolver,
top_level_defs, top_level_defs,
@ -203,7 +195,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
locked.insert(obj_id, type_vars.clone()); locked.insert(obj_id, type_vars.clone());
locked.clone() locked.clone()
}, },
Some(type_vars[idx]), None,
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -319,6 +311,46 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
Ok(TypeAnnotation::Tuple(type_annotations)) Ok(TypeAnnotation::Tuple(type_annotations))
} }
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} => {
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| {
match &e.node {
ast::ExprKind::Constant { value, .. } => Ok(
TypeAnnotation::Literal(vec![value.clone()]),
),
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
None,
),
}
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
Ok(TypeAnnotation::Literal(type_annotations))
}
// custom class // custom class
ast::ExprKind::Subscript { value, slice, .. } => { ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
@ -331,30 +363,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
ast::ExprKind::Constant { value, .. } => { ast::ExprKind::Constant { value, .. } => {
let type_var = type_var.expect("Expect type variable to be present"); Ok(TypeAnnotation::Literal(vec![value.clone()]))
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)
.map_err(|err| HashSet::from([err]))?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(HashSet::from([
format!(
"expression {value} is not allowed for constant type annotation (at {})",
expr.location
),
]))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([
@ -495,14 +504,13 @@ pub fn get_type_from_type_annotation_kinds(
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => { TypeAnnotation::Literal(values) => {
let ty_enum = unifier.get_ty(*ty); let values = values.iter()
let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else { .map(|v| SymbolValue::from_constant_inferred(v, unifier))
unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()); .collect::<Result<Vec<_>, _>>()
}; .map_err(|err| HashSet::from([err]))?;
let ty = ntv_underlying_ty[0]; let var = unifier.get_fresh_literal(values, None);
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var) Ok(var)
} }
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
@ -576,7 +584,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
} }
result result
} }

View File

@ -140,12 +140,10 @@ pub enum TypeEnum {
is_const_generic: bool, is_const_generic: bool,
}, },
/// A constant for substitution into a const generic variable. /// A literal generic type matching `typing.Literal`.
TConstant { TLiteral {
/// The value of the constant. /// The value of the constant.
value: SymbolValue, values: Vec<SymbolValue>,
/// The underlying type of the value.
ty: Type,
loc: Option<Location>, loc: Option<Location>,
}, },
@ -192,7 +190,7 @@ impl TypeEnum {
match self { match self {
TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar", TypeEnum::TVar { .. } => "TVar",
TypeEnum::TConstant { .. } => "TConstant", TypeEnum::TLiteral { .. } => "TConstant",
TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList", TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj", TypeEnum::TObj { .. } => "TObj",
@ -371,8 +369,7 @@ impl Unifier {
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
} }
/// Returns a fresh type representing a constant generic variable with the given underlying type /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`.
/// `ty`.
pub fn get_fresh_const_generic_var( pub fn get_fresh_const_generic_var(
&mut self, &mut self,
ty: Type, ty: Type,
@ -384,17 +381,17 @@ impl Unifier {
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) (self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
} }
/// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`.
/// `value` and type `ty`. pub fn get_fresh_literal(
pub fn get_fresh_constant(
&mut self, &mut self,
value: SymbolValue, values: Vec<SymbolValue>,
ty: Type,
loc: Option<Location>, loc: Option<Location>,
) -> Type { ) -> Type {
assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. })); let ty_enum = TypeEnum::TLiteral {
values: values.clone(),
self.add_ty(TypeEnum::TConstant { ty, value, loc }) loc
};
self.add_ty(ty_enum)
} }
/// Unification would not unify rigid variables with other types, but we want to do this for /// Unification would not unify rigid variables with other types, but we want to do this for
@ -469,7 +466,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } | TConstant { .. } => true, TRigidVar { .. } | TLiteral { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
@ -747,18 +744,54 @@ impl Unifier {
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => { (TVar { range: tys, is_const_generic: true, .. }, TLiteral { values, .. }) => {
let ty1 = ty1[0]; assert_eq!(tys.len(), 1);
assert_eq!(values.len(), 1);
let primitives = &self.primitive_store
.expect("Expected PrimitiveStore to be present");
let ty = tys[0];
let value= &values[0];
let value_ty = value.get_type(primitives, self);
// If the types don't match, try to implicitly promote integers
if !self.unioned(ty, value_ty) {
let num_val = match *value {
SymbolValue::I32(v) => v as i128,
SymbolValue::I64(v) => v as i128,
SymbolValue::U32(v) => v as i128,
SymbolValue::U64(v) => v as i128,
_ => return self.incompatible_types(a, b),
};
let can_convert = if self.unioned(ty, primitives.int32) {
i32::try_from(num_val).is_ok()
} else if self.unioned(ty, primitives.int64) {
i64::try_from(num_val).is_ok()
} else if self.unioned(ty, primitives.uint32) {
u32::try_from(num_val).is_ok()
} else if self.unioned(ty, primitives.uint64) {
u64::try_from(num_val).is_ok()
} else {
false
};
if !can_convert {
return self.incompatible_types(a, b)
}
}
self.unify_impl(ty1, *ty2, false)?;
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
(TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => { (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
if val1 != val2 { for (v1, v2) in zip(val1, val2) {
return self.incompatible_types(a, b) if v1 != v2 {
return self.incompatible_types(a, b)
}
} }
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
@ -1016,8 +1049,8 @@ impl Unifier {
}; };
n n
} }
TypeEnum::TConstant { value, .. } => { TypeEnum::TLiteral { values, .. } => {
format!("const({value})") format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut fields = let mut fields =

View File

@ -16,28 +16,28 @@ class HybridGenericClass2(Generic[A, T]):
class HybridGenericClass3(Generic[T, A, B]): class HybridGenericClass3(Generic[T, A, B]):
pass pass
def make_generic_2() -> ConstGenericClass[2]: def make_generic_2() -> ConstGenericClass[Literal[2]]:
return ... return ...
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]: def make_generic2_1_2() -> ConstGeneric2Class[Literal[1], Literal[2]]:
return ... return ...
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]: def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]:
return ... return ...
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
return ... return ...
def consume_generic_2(instance: ConstGenericClass[2]): def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
pass pass
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]): def consume_generic2_1_2(instance: ConstGeneric2Class[Literal[1], Literal[2]]):
pass pass
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]): def consume_hybrid_class_2_i32(instance: HybridGenericClass2[Literal[2], int32]):
pass pass
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]): def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0], Literal[1]]):
pass pass
def f(): def f():