forked from M-Labs/nac3
core: Initial implementation for const generics
This commit is contained in:
parent
b6dfcfcc38
commit
031e660f18
|
@ -85,7 +85,7 @@ impl SymbolValue {
|
||||||
.map(|val| SymbolValue::U64(val))
|
.map(|val| SymbolValue::U64(val))
|
||||||
.map_err(|e| e.to_string())
|
.map_err(|e| e.to_string())
|
||||||
} else {
|
} else {
|
||||||
Err(format!("Expected {:?}, but got int", expected_ty))
|
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Constant::Tuple(t) => {
|
Constant::Tuple(t) => {
|
||||||
|
|
|
@ -560,6 +560,7 @@ impl TopLevelComposer {
|
||||||
&primitive_types,
|
&primitive_types,
|
||||||
b,
|
b,
|
||||||
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(),
|
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(),
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
||||||
|
@ -894,6 +895,7 @@ impl TopLevelComposer {
|
||||||
// NOTE: since only class need this, for function
|
// NOTE: since only class need this, for function
|
||||||
// it should be fine to be empty map
|
// it should be fine to be empty map
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let type_vars_within =
|
let type_vars_within =
|
||||||
|
@ -961,6 +963,7 @@ impl TopLevelComposer {
|
||||||
// NOTE: since only class need this, for function
|
// NOTE: since only class need this, for function
|
||||||
// it should be fine to be empty map
|
// it should be fine to be empty map
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
|
None,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1158,6 +1161,7 @@ impl TopLevelComposer {
|
||||||
vec![(class_id, class_type_vars_def.clone())]
|
vec![(class_id, class_type_vars_def.clone())]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect(),
|
.collect(),
|
||||||
|
None,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
// find type vars within this method parameter type annotation
|
// find type vars within this method parameter type annotation
|
||||||
|
@ -1221,6 +1225,7 @@ impl TopLevelComposer {
|
||||||
primitives,
|
primitives,
|
||||||
result,
|
result,
|
||||||
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
|
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
// find type vars within this return type annotation
|
// find type vars within this return type annotation
|
||||||
let type_vars_within =
|
let type_vars_within =
|
||||||
|
@ -1317,6 +1322,7 @@ impl TopLevelComposer {
|
||||||
primitives,
|
primitives,
|
||||||
annotation.as_ref(),
|
annotation.as_ref(),
|
||||||
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
|
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
// find type vars within this return type annotation
|
// find type vars within this return type annotation
|
||||||
let type_vars_within =
|
let type_vars_within =
|
||||||
|
@ -1735,7 +1741,7 @@ impl TopLevelComposer {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, ty)| {
|
.map(|(_, ty)| {
|
||||||
unifier.get_instantiations(*ty).unwrap_or_else(|| {
|
unifier.get_instantiations(*ty).unwrap_or_else(|| {
|
||||||
if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty)
|
if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty)
|
||||||
{
|
{
|
||||||
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
|
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
|
||||||
no_ranges.push(rigid);
|
no_ranges.push(rigid);
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -12,6 +13,16 @@ pub enum TypeAnnotation {
|
||||||
// can only be CustomClassKind
|
// can only be CustomClassKind
|
||||||
Virtual(Box<TypeAnnotation>),
|
Virtual(Box<TypeAnnotation>),
|
||||||
TypeVar(Type),
|
TypeVar(Type),
|
||||||
|
/// A constant used in the context of a const-generic variable.
|
||||||
|
Constant {
|
||||||
|
/// The non-type variable associated with this constant.
|
||||||
|
///
|
||||||
|
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
|
||||||
|
/// const generic variable of which this constant is associated with.
|
||||||
|
ty: Type,
|
||||||
|
/// The constant value of this constant.
|
||||||
|
value: SymbolValue
|
||||||
|
},
|
||||||
List(Box<TypeAnnotation>),
|
List(Box<TypeAnnotation>),
|
||||||
Tuple(Vec<TypeAnnotation>),
|
Tuple(Vec<TypeAnnotation>),
|
||||||
}
|
}
|
||||||
|
@ -47,6 +58,7 @@ impl TypeAnnotation {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Constant { value, .. } => format!("Const({value})"),
|
||||||
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
|
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
|
||||||
List(ty) => format!("list[{}]", ty.stringify(unifier)),
|
List(ty) => format!("list[{}]", ty.stringify(unifier)),
|
||||||
Tuple(types) => {
|
Tuple(types) => {
|
||||||
|
@ -56,6 +68,12 @@ impl TypeAnnotation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parses an AST expression `expr` into a [TypeAnnotation].
|
||||||
|
///
|
||||||
|
/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all
|
||||||
|
/// generic variables associated with the definition.
|
||||||
|
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
|
||||||
|
/// [None] when this function is invoked externally.
|
||||||
pub fn parse_ast_to_type_annotation_kinds<T>(
|
pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||||
|
@ -64,6 +82,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
expr: &ast::Expr<T>,
|
expr: &ast::Expr<T>,
|
||||||
// the key stores the type_var of this topleveldef::class, we only need this field here
|
// the key stores the type_var of this topleveldef::class, we only need this field here
|
||||||
locked: HashMap<DefinitionId, Vec<Type>>,
|
locked: HashMap<DefinitionId, Vec<Type>>,
|
||||||
|
type_var: Option<Type>,
|
||||||
) -> Result<TypeAnnotation, String> {
|
) -> Result<TypeAnnotation, String> {
|
||||||
let name_handle = |id: &StrRef,
|
let name_handle = |id: &StrRef,
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
|
@ -161,7 +180,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
}
|
}
|
||||||
let result = params_ast
|
let result = params_ast
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| {
|
.enumerate()
|
||||||
|
.map(|(idx, x)| {
|
||||||
parse_ast_to_type_annotation_kinds(
|
parse_ast_to_type_annotation_kinds(
|
||||||
resolver,
|
resolver,
|
||||||
top_level_defs,
|
top_level_defs,
|
||||||
|
@ -172,6 +192,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
locked.insert(obj_id, type_vars.clone());
|
locked.insert(obj_id, type_vars.clone());
|
||||||
locked.clone()
|
locked.clone()
|
||||||
},
|
},
|
||||||
|
Some(type_vars[idx]),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
@ -190,6 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
};
|
};
|
||||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
|
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
|
||||||
};
|
};
|
||||||
|
|
||||||
match &expr.node {
|
match &expr.node {
|
||||||
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
|
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
|
||||||
// virtual
|
// virtual
|
||||||
|
@ -205,6 +227,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
primitives,
|
primitives,
|
||||||
slice.as_ref(),
|
slice.as_ref(),
|
||||||
locked,
|
locked,
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
|
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
|
||||||
unreachable!("must be concretized custom class kind in the virtual")
|
unreachable!("must be concretized custom class kind in the virtual")
|
||||||
|
@ -225,6 +248,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
primitives,
|
primitives,
|
||||||
slice.as_ref(),
|
slice.as_ref(),
|
||||||
locked,
|
locked,
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
Ok(TypeAnnotation::List(def_ann.into()))
|
Ok(TypeAnnotation::List(def_ann.into()))
|
||||||
}
|
}
|
||||||
|
@ -242,6 +266,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
primitives,
|
primitives,
|
||||||
slice.as_ref(),
|
slice.as_ref(),
|
||||||
locked,
|
locked,
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
let id =
|
let id =
|
||||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
|
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
|
||||||
|
@ -275,6 +300,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
primitives,
|
primitives,
|
||||||
e,
|
e,
|
||||||
locked.clone(),
|
locked.clone(),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
@ -290,6 +316,31 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ast::ExprKind::Constant { value, .. } => {
|
||||||
|
let type_var = type_var.expect("Expect type variable to be present");
|
||||||
|
|
||||||
|
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
|
||||||
|
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let underlying_ty = underlying_ty[0];
|
||||||
|
|
||||||
|
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)?;
|
||||||
|
|
||||||
|
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
|
||||||
|
return Err(format!(
|
||||||
|
"expression {} is not allowed for constant type annotation (at {})",
|
||||||
|
value.to_string(),
|
||||||
|
expr.location
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TypeAnnotation::Constant {
|
||||||
|
ty: type_var,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
|
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -308,14 +359,18 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
TypeAnnotation::CustomClass { id: obj_id, params } => {
|
TypeAnnotation::CustomClass { id: obj_id, params } => {
|
||||||
let def_read = top_level_defs[obj_id.0].read();
|
let def_read = top_level_defs[obj_id.0].read();
|
||||||
let class_def: &TopLevelDef = def_read.deref();
|
let class_def: &TopLevelDef = def_read.deref();
|
||||||
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def {
|
let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
|
||||||
|
unreachable!("should be class def here")
|
||||||
|
};
|
||||||
|
|
||||||
if type_vars.len() != params.len() {
|
if type_vars.len() != params.len() {
|
||||||
Err(format!(
|
return Err(format!(
|
||||||
"unexpected number of type parameters: expected {} but got {}",
|
"unexpected number of type parameters: expected {} but got {}",
|
||||||
type_vars.len(),
|
type_vars.len(),
|
||||||
params.len()
|
params.len()
|
||||||
))
|
))
|
||||||
} else {
|
}
|
||||||
|
|
||||||
let param_ty = params
|
let param_ty = params
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| {
|
.map(|x| {
|
||||||
|
@ -334,9 +389,8 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
||||||
let mut result: HashMap<u32, Type> = HashMap::new();
|
let mut result: HashMap<u32, Type> = HashMap::new();
|
||||||
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
||||||
if let TypeEnum::TVar { id, range, fields: None, name, loc } =
|
match unifier.get_ty(*tvar).as_ref() {
|
||||||
unifier.get_ty(*tvar).as_ref()
|
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
|
||||||
{
|
|
||||||
let ok: bool = {
|
let ok: bool = {
|
||||||
// create a temp type var and unify to check compatibility
|
// create a temp type var and unify to check compatibility
|
||||||
p == *tvar || {
|
p == *tvar || {
|
||||||
|
@ -362,8 +416,33 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
*id
|
*id
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
|
||||||
|
let ty = range[0];
|
||||||
|
let ok: bool = {
|
||||||
|
// create a temp type var and unify to check compatibility
|
||||||
|
p == *tvar || {
|
||||||
|
let temp = unifier.get_fresh_const_generic_var(
|
||||||
|
ty,
|
||||||
|
*name,
|
||||||
|
*loc,
|
||||||
|
);
|
||||||
|
unifier.unify(temp.0, p).is_ok()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if ok {
|
||||||
|
result.insert(*id, p);
|
||||||
} else {
|
} else {
|
||||||
unreachable!("must be generic type var")
|
return Err(format!(
|
||||||
|
"cannot apply type {} to type variable {}",
|
||||||
|
unifier.stringify(p),
|
||||||
|
name.unwrap_or_else(|| format!("typevar{id}").into()),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!("must be generic type var"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
|
@ -391,11 +470,19 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
}
|
}
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
unreachable!("should be class def here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
||||||
|
TypeAnnotation::Constant { ty, value, .. } => {
|
||||||
|
let ty_enum = unifier.get_ty(*ty);
|
||||||
|
let (ty, loc) = match &*ty_enum {
|
||||||
|
TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => {
|
||||||
|
(ntv_underlying_ty[0], loc)
|
||||||
|
}
|
||||||
|
_ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
|
||||||
|
Ok(var)
|
||||||
|
}
|
||||||
TypeAnnotation::Virtual(ty) => {
|
TypeAnnotation::Virtual(ty) => {
|
||||||
let ty = get_type_from_type_annotation_kinds(
|
let ty = get_type_from_type_annotation_kinds(
|
||||||
top_level_defs,
|
top_level_defs,
|
||||||
|
@ -470,7 +557,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
|
||||||
result.extend(get_type_var_contained_in_type_annotation(a));
|
result.extend(get_type_var_contained_in_type_annotation(a));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeAnnotation::Primitive(..) => {}
|
TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {}
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
|
@ -134,6 +134,17 @@ pub enum TypeEnum {
|
||||||
range: Vec<Type>,
|
range: Vec<Type>,
|
||||||
name: Option<StrRef>,
|
name: Option<StrRef>,
|
||||||
loc: Option<Location>,
|
loc: Option<Location>,
|
||||||
|
/// Whether this type variable refers to a const-generic variable.
|
||||||
|
is_const_generic: bool,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// A constant for substitution into a const generic variable.
|
||||||
|
TConstant {
|
||||||
|
/// The value of the constant.
|
||||||
|
value: SymbolValue,
|
||||||
|
/// The underlying type of the value.
|
||||||
|
ty: Type,
|
||||||
|
loc: Option<Location>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// A tuple type.
|
/// A tuple type.
|
||||||
|
@ -178,6 +189,7 @@ impl TypeEnum {
|
||||||
match self {
|
match self {
|
||||||
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
||||||
TypeEnum::TVar { .. } => "TVar",
|
TypeEnum::TVar { .. } => "TVar",
|
||||||
|
TypeEnum::TConstant { .. } => "TConstant",
|
||||||
TypeEnum::TTuple { .. } => "TTuple",
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
TypeEnum::TList { .. } => "TList",
|
TypeEnum::TList { .. } => "TList",
|
||||||
TypeEnum::TObj { .. } => "TObj",
|
TypeEnum::TObj { .. } => "TObj",
|
||||||
|
@ -263,6 +275,7 @@ impl Unifier {
|
||||||
fields: Some(fields),
|
fields: Some(fields),
|
||||||
name: None,
|
name: None,
|
||||||
loc: None,
|
loc: None,
|
||||||
|
is_const_generic: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,7 +349,33 @@ impl Unifier {
|
||||||
let id = self.var_id + 1;
|
let id = self.var_id + 1;
|
||||||
self.var_id += 1;
|
self.var_id += 1;
|
||||||
let range = range.to_vec();
|
let range = range.to_vec();
|
||||||
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id)
|
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a fresh type representing a constant generic variable with the given underlying type
|
||||||
|
/// `ty`.
|
||||||
|
pub fn get_fresh_const_generic_var(
|
||||||
|
&mut self,
|
||||||
|
ty: Type,
|
||||||
|
name: Option<StrRef>,
|
||||||
|
loc: Option<Location>,
|
||||||
|
) -> (Type, u32) {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given
|
||||||
|
/// `value` and type `ty`.
|
||||||
|
pub fn get_fresh_constant(
|
||||||
|
&mut self,
|
||||||
|
value: SymbolValue,
|
||||||
|
ty: Type,
|
||||||
|
loc: Option<Location>,
|
||||||
|
) -> Type {
|
||||||
|
assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. }));
|
||||||
|
|
||||||
|
self.add_ty(TypeEnum::TConstant { ty, value, loc })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unification would not unify rigid variables with other types, but we want to do this for
|
/// Unification would not unify rigid variables with other types, but we want to do this for
|
||||||
|
@ -412,7 +451,7 @@ impl Unifier {
|
||||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||||
use TypeEnum::*;
|
use TypeEnum::*;
|
||||||
match &*self.get_ty(a) {
|
match &*self.get_ty(a) {
|
||||||
TRigidVar { .. } => true,
|
TRigidVar { .. } | TConstant { .. } => true,
|
||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
@ -560,8 +599,8 @@ impl Unifier {
|
||||||
};
|
};
|
||||||
match (&*ty_a, &*ty_b) {
|
match (&*ty_a, &*ty_b) {
|
||||||
(
|
(
|
||||||
TVar { fields: fields1, id, name: name1, loc: loc1, .. },
|
TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. },
|
||||||
TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. },
|
TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. },
|
||||||
) => {
|
) => {
|
||||||
let new_fields = match (fields1, fields2) {
|
let new_fields = match (fields1, fields2) {
|
||||||
(None, None) => None,
|
(None, None) => None,
|
||||||
|
@ -616,10 +655,11 @@ impl Unifier {
|
||||||
range,
|
range,
|
||||||
name: name1.or(*name2),
|
name: name1.or(*name2),
|
||||||
loc: loc1.or(*loc2),
|
loc: loc1.or(*loc2),
|
||||||
|
is_const_generic: false,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
(TVar { fields: None, range, .. }, _) => {
|
(TVar { fields: None, range, is_const_generic: false, .. }, _) => {
|
||||||
// We check for the range of the type variable to see if unification is allowed.
|
// We check for the range of the type variable to see if unification is allowed.
|
||||||
// Note that although b may be compatible with a, we may have to constrain type
|
// Note that although b may be compatible with a, we may have to constrain type
|
||||||
// variables in b to make sure that instantiations of b would always be compatible
|
// variables in b to make sure that instantiations of b would always be compatible
|
||||||
|
@ -636,7 +676,7 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
(TVar { fields: Some(fields), range, .. }, TTuple { ty }) => {
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
|
||||||
let len = ty.len() as i32;
|
let len = ty.len() as i32;
|
||||||
for (k, v) in fields.iter() {
|
for (k, v) in fields.iter() {
|
||||||
match *k {
|
match *k {
|
||||||
|
@ -666,7 +706,7 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
(TVar { fields: Some(fields), range, .. }, TList { ty }) => {
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
|
||||||
for (k, v) in fields.iter() {
|
for (k, v) in fields.iter() {
|
||||||
match *k {
|
match *k {
|
||||||
RecordKey::Int(_) => {
|
RecordKey::Int(_) => {
|
||||||
|
@ -681,6 +721,35 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => {
|
||||||
|
let ty1 = ty1[0];
|
||||||
|
let ty2 = ty2[0];
|
||||||
|
|
||||||
|
if id1 != id2 {
|
||||||
|
self.unify_impl(ty1, ty2, false)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
(TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => {
|
||||||
|
let ty1 = ty1[0];
|
||||||
|
|
||||||
|
self.unify_impl(ty1, *ty2, false)?;
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
(TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => {
|
||||||
|
if val1 != val2 {
|
||||||
|
eprintln!("VALUE MISMATCH: lhs={val1:?} rhs={val2:?} eq={}", val1 == val2);
|
||||||
|
return self.incompatible_types(a, b)
|
||||||
|
}
|
||||||
|
self.unify_impl(*ty1, *ty2, false)?;
|
||||||
|
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||||
if ty1.len() != ty2.len() {
|
if ty1.len() != ty2.len() {
|
||||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||||
|
@ -775,7 +844,14 @@ impl Unifier {
|
||||||
if id1 != id2 {
|
if id1 != id2 {
|
||||||
self.incompatible_types(a, b)?;
|
self.incompatible_types(a, b)?;
|
||||||
}
|
}
|
||||||
for (x, y) in zip(params1.values(), params2.values()) {
|
|
||||||
|
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
|
||||||
|
// all K-V pairs "in arbitrary order"
|
||||||
|
let (tv1, tv2) = (
|
||||||
|
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
|
||||||
|
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
|
||||||
|
);
|
||||||
|
for (x, y) in zip(tv1, tv2) {
|
||||||
if self.unify_impl(*x, *y, false).is_err() {
|
if self.unify_impl(*x, *y, false).is_err() {
|
||||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||||
};
|
};
|
||||||
|
@ -928,6 +1004,9 @@ impl Unifier {
|
||||||
};
|
};
|
||||||
n
|
n
|
||||||
}
|
}
|
||||||
|
TypeEnum::TConstant { value, .. } => {
|
||||||
|
format!("const({value})")
|
||||||
|
}
|
||||||
TypeEnum::TTuple { ty } => {
|
TypeEnum::TTuple { ty } => {
|
||||||
let mut fields =
|
let mut fields =
|
||||||
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
|
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
|
||||||
|
@ -983,8 +1062,8 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unifies `a` and `b` together, and set the value to the value of `b`.
|
||||||
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
||||||
// unify a and b together, and set the value to b's value.
|
|
||||||
let table = &mut self.unification_table;
|
let table = &mut self.unification_table;
|
||||||
let ty_b = table.probe_value(b).clone();
|
let ty_b = table.probe_value(b).clone();
|
||||||
table.unify(a, b);
|
table.unify(a, b);
|
||||||
|
@ -1207,6 +1286,7 @@ impl Unifier {
|
||||||
range,
|
range,
|
||||||
name: name2.or(*name),
|
name: name2.or(*name),
|
||||||
loc: loc2.or(*loc),
|
loc: loc2.or(*loc),
|
||||||
|
is_const_generic: false,
|
||||||
};
|
};
|
||||||
Ok(Some(self.unification_table.new_key(ty.into())))
|
Ok(Some(self.unification_table.new_key(ty.into())))
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,12 @@ def Some(v: T) -> Option[T]:
|
||||||
|
|
||||||
none = Option(None)
|
none = Option(None)
|
||||||
|
|
||||||
|
class _ConstGenericMarker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ConstGeneric(name, constraint):
|
||||||
|
return TypeVar(name, _ConstGenericMarker, constraint)
|
||||||
|
|
||||||
def round_away_zero(x):
|
def round_away_zero(x):
|
||||||
if x >= 0.0:
|
if x >= 0.0:
|
||||||
return math.floor(x + 0.5)
|
return math.floor(x + 0.5)
|
||||||
|
@ -99,6 +105,7 @@ def patch(module):
|
||||||
module.uint32 = uint32
|
module.uint32 = uint32
|
||||||
module.uint64 = uint64
|
module.uint64 = uint64
|
||||||
module.TypeVar = TypeVar
|
module.TypeVar = TypeVar
|
||||||
|
module.ConstGeneric = ConstGeneric
|
||||||
module.Generic = Generic
|
module.Generic = Generic
|
||||||
module.extern = extern
|
module.extern = extern
|
||||||
module.Option = Option
|
module.Option = Option
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
A = ConstGeneric("A", int32)
|
||||||
|
B = ConstGeneric("B", uint32)
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
class ConstGenericClass(Generic[A]):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ConstGeneric2Class(Generic[A, B]):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class HybridGenericClass2(Generic[A, T]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class HybridGenericClass3(Generic[T, A, B]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def make_generic_2() -> ConstGenericClass[2]:
|
||||||
|
return ...
|
||||||
|
|
||||||
|
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]:
|
||||||
|
return ...
|
||||||
|
|
||||||
|
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]:
|
||||||
|
return ...
|
||||||
|
|
||||||
|
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]:
|
||||||
|
return ...
|
||||||
|
|
||||||
|
def consume_generic_2(instance: ConstGenericClass[2]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def f():
|
||||||
|
consume_generic_2(make_generic_2())
|
||||||
|
consume_generic2_1_2(make_generic2_1_2())
|
||||||
|
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
|
||||||
|
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
|
||||||
|
|
||||||
|
def run() -> int32:
|
||||||
|
return 0
|
|
@ -25,7 +25,7 @@ use nac3core::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use nac3parser::{
|
use nac3parser::{
|
||||||
ast::{Expr, ExprKind, StmtKind},
|
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
|
||||||
parser,
|
parser,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -83,6 +83,11 @@ fn handle_typevar_definition(
|
||||||
|
|
||||||
match &func.node {
|
match &func.node {
|
||||||
ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
|
ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
|
||||||
|
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
|
||||||
|
return Err(format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node))
|
||||||
|
};
|
||||||
|
let generic_name: StrRef = ty_name.to_string().into();
|
||||||
|
|
||||||
let constraints = args
|
let constraints = args
|
||||||
.iter()
|
.iter()
|
||||||
.skip(1)
|
.skip(1)
|
||||||
|
@ -94,13 +99,50 @@ fn handle_typevar_definition(
|
||||||
primitives,
|
primitives,
|
||||||
x,
|
x,
|
||||||
Default::default(),
|
Default::default(),
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
get_type_from_type_annotation_kinds(
|
get_type_from_type_annotation_kinds(
|
||||||
def_list, unifier, primitives, &ty, &mut None
|
def_list, unifier, primitives, &ty, &mut None
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0)
|
let loc = func.location;
|
||||||
|
|
||||||
|
if constraints.len() == 1 {
|
||||||
|
return Err(format!("A single constraint is not allowed (at {})", loc))
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
|
||||||
|
}
|
||||||
|
|
||||||
|
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
|
||||||
|
if args.len() != 2 {
|
||||||
|
return Err(format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
|
||||||
|
&args[0].node
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let generic_name: StrRef = ty_name.to_string().into();
|
||||||
|
|
||||||
|
let ty = parse_ast_to_type_annotation_kinds(
|
||||||
|
resolver,
|
||||||
|
def_list,
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
&args[1],
|
||||||
|
Default::default(),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
let constraint = get_type_from_type_annotation_kinds(
|
||||||
|
def_list, unifier, primitives, &ty, &mut None
|
||||||
|
)?;
|
||||||
|
let loc = func.location;
|
||||||
|
|
||||||
|
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => Err(format!(
|
_ => Err(format!(
|
||||||
|
|
Loading…
Reference in New Issue