core: Initial implementation for const generics

This commit is contained in:
David Mak 2023-12-05 14:37:08 +08:00
parent b6dfcfcc38
commit 031e660f18
7 changed files with 365 additions and 93 deletions

View File

@ -85,7 +85,7 @@ impl SymbolValue {
.map(|val| SymbolValue::U64(val)) .map(|val| SymbolValue::U64(val))
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
} else { } else {
Err(format!("Expected {:?}, but got int", expected_ty)) Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
} }
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {

View File

@ -560,6 +560,7 @@ impl TopLevelComposer {
&primitive_types, &primitive_types,
b, b,
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(),
None,
)?; )?;
if let TypeAnnotation::CustomClass { .. } = &base_ty { if let TypeAnnotation::CustomClass { .. } = &base_ty {
@ -894,6 +895,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)?; )?;
let type_vars_within = let type_vars_within =
@ -961,6 +963,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)? )?
}; };
@ -1158,6 +1161,7 @@ impl TopLevelComposer {
vec![(class_id, class_type_vars_def.clone())] vec![(class_id, class_type_vars_def.clone())]
.into_iter() .into_iter()
.collect(), .collect(),
None,
)? )?
}; };
// find type vars within this method parameter type annotation // find type vars within this method parameter type annotation
@ -1221,6 +1225,7 @@ impl TopLevelComposer {
primitives, primitives,
result, result,
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1317,6 +1322,7 @@ impl TopLevelComposer {
primitives, primitives,
annotation.as_ref(), annotation.as_ref(),
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1735,7 +1741,7 @@ impl TopLevelComposer {
.iter() .iter()
.map(|(_, ty)| { .map(|(_, ty)| {
unifier.get_instantiations(*ty).unwrap_or_else(|| { unifier.get_instantiations(*ty).unwrap_or_else(|| {
if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty)
{ {
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
no_ranges.push(rigid); no_ranges.push(rigid);

View File

@ -1,3 +1,4 @@
use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -12,6 +13,16 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
/// A constant used in the context of a const-generic variable.
Constant {
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>), List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -47,6 +58,7 @@ impl TypeAnnotation {
} }
) )
} }
Constant { value, .. } => format!("Const({value})"),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
@ -56,6 +68,12 @@ impl TypeAnnotation {
} }
} }
/// Parses an AST expression `expr` into a [TypeAnnotation].
///
/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [None] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -64,6 +82,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
type_var: Option<Type>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
@ -161,7 +180,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
let result = params_ast let result = params_ast
.iter() .iter()
.map(|x| { .enumerate()
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds( parse_ast_to_type_annotation_kinds(
resolver, resolver,
top_level_defs, top_level_defs,
@ -172,6 +192,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
locked.insert(obj_id, type_vars.clone()); locked.insert(obj_id, type_vars.clone());
locked.clone() locked.clone()
}, },
Some(type_vars[idx]),
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -190,6 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
}; };
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual // virtual
@ -205,6 +227,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) { if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual") unreachable!("must be concretized custom class kind in the virtual")
@ -225,6 +248,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
Ok(TypeAnnotation::List(def_ann.into())) Ok(TypeAnnotation::List(def_ann.into()))
} }
@ -242,6 +266,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
let id = let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
@ -275,6 +300,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
e, e,
locked.clone(), locked.clone(),
None,
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -290,6 +316,31 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
} }
ast::ExprKind::Constant { value, .. } => {
let type_var = type_var.expect("Expect type variable to be present");
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(format!(
"expression {} is not allowed for constant type annotation (at {})",
value.to_string(),
expr.location
))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
}
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)), _ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
} }
} }
@ -308,94 +359,130 @@ pub fn get_type_from_type_annotation_kinds(
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read(); let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = def_read.deref(); let class_def: &TopLevelDef = def_read.deref();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
if type_vars.len() != params.len() { unreachable!("should be class def here")
Err(format!( };
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
))
} else {
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let subst = { if type_vars.len() != params.len() {
// check for compatible range return Err(format!(
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check "unexpected number of type parameters: expected {} but got {}",
let mut result: HashMap<u32, Type> = HashMap::new(); type_vars.len(),
for (tvar, p) in type_vars.iter().zip(param_ty) { params.len()
if let TypeEnum::TVar { id, range, fields: None, name, loc } = ))
unifier.get_ty(*tvar).as_ref() }
{
let ok: bool = { let param_ty = params
// create a temp type var and unify to check compatibility .iter()
p == *tvar || { .map(|x| {
let temp = unifier.get_fresh_var_with_range( get_type_from_type_annotation_kinds(
range.as_slice(), top_level_defs,
*name, unifier,
*loc, primitives,
); x,
unifier.unify(temp.0, p).is_ok() subst_list
} )
}; })
if ok { .collect::<Result<Vec<_>, _>>()?;
result.insert(*id, p);
} else { let subst = {
return Err(format!( // check for compatible range
"cannot apply type {} to type variable with id {:?}", // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
unifier.internal_stringify( let mut result: HashMap<u32, Type> = HashMap::new();
p, for (tvar, p) in type_vars.iter().zip(param_ty) {
&mut |id| format!("class{}", id), match unifier.get_ty(*tvar).as_ref() {
&mut |id| format!("typevar{}", id), TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
&mut None let ok: bool = {
), // create a temp type var and unify to check compatibility
*id p == *tvar || {
)); let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
} }
};
if ok {
result.insert(*id, p);
} else { } else {
unreachable!("must be generic type var") return Err(format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{}", id),
&mut |id| format!("typevar{}", id),
&mut None
),
*id
));
} }
} }
result
}; TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let mut tobj_fields = methods let ty = range[0];
.iter() let ok: bool = {
.map(|(name, ty, _)| { // create a temp type var and unify to check compatibility
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); p == *tvar || {
// methods are immutable let temp = unifier.get_fresh_const_generic_var(
(*name, (subst_ty, false)) ty,
}) *name,
.collect::<HashMap<_, _>>(); *loc,
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { );
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); unifier.unify(temp.0, p).is_ok()
(*name, (subst_ty, *mutability)) }
})); };
let need_subst = !subst.is_empty(); if ok {
let ty = unifier.add_ty(TypeEnum::TObj { result.insert(*id, p);
obj_id: *obj_id, } else {
fields: tobj_fields, return Err(format!(
params: subst, "cannot apply type {} to type variable {}",
}); unifier.stringify(p),
if need_subst { name.unwrap_or_else(|| format!("typevar{id}").into()),
subst_list.as_mut().map(|wl| wl.push(ty)); ))
}
}
_ => unreachable!("must be generic type var"),
} }
Ok(ty)
} }
} else { result
unreachable!("should be class def here") };
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
subst_list.as_mut().map(|wl| wl.push(ty));
} }
Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => {
let ty_enum = unifier.get_ty(*ty);
let (ty, loc) = match &*ty_enum {
TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => {
(ntv_underlying_ty[0], loc)
}
_ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()),
};
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var)
}
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
@ -470,7 +557,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {}
} }
result result
} }

View File

@ -134,6 +134,17 @@ pub enum TypeEnum {
range: Vec<Type>, range: Vec<Type>,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
/// Whether this type variable refers to a const-generic variable.
is_const_generic: bool,
},
/// A constant for substitution into a const generic variable.
TConstant {
/// The value of the constant.
value: SymbolValue,
/// The underlying type of the value.
ty: Type,
loc: Option<Location>,
}, },
/// A tuple type. /// A tuple type.
@ -178,6 +189,7 @@ impl TypeEnum {
match self { match self {
TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar", TypeEnum::TVar { .. } => "TVar",
TypeEnum::TConstant { .. } => "TConstant",
TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList", TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj", TypeEnum::TObj { .. } => "TObj",
@ -263,6 +275,7 @@ impl Unifier {
fields: Some(fields), fields: Some(fields),
name: None, name: None,
loc: None, loc: None,
is_const_generic: false,
}) })
} }
@ -336,7 +349,33 @@ impl Unifier {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id) (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
}
/// Returns a fresh type representing a constant generic variable with the given underlying type
/// `ty`.
pub fn get_fresh_const_generic_var(
&mut self,
ty: Type,
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
}
/// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given
/// `value` and type `ty`.
pub fn get_fresh_constant(
&mut self,
value: SymbolValue,
ty: Type,
loc: Option<Location>,
) -> Type {
assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. }));
self.add_ty(TypeEnum::TConstant { ty, value, loc })
} }
/// Unification would not unify rigid variables with other types, but we want to do this for /// Unification would not unify rigid variables with other types, but we want to do this for
@ -412,7 +451,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } => true, TRigidVar { .. } | TConstant { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } => self.is_concrete(*ty, allowed_typevars),
@ -560,8 +599,8 @@ impl Unifier {
}; };
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
( (
TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. },
TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. }, TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. },
) => { ) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
@ -616,10 +655,11 @@ impl Unifier {
range, range,
name: name1.or(*name2), name: name1.or(*name2),
loc: loc1.or(*loc2), loc: loc1.or(*loc2),
is_const_generic: false,
}), }),
); );
} }
(TVar { fields: None, range, .. }, _) => { (TVar { fields: None, range, is_const_generic: false, .. }, _) => {
// We check for the range of the type variable to see if unification is allowed. // We check for the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type // Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible // variables in b to make sure that instantiations of b would always be compatible
@ -636,7 +676,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TTuple { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = ty.len() as i32; let len = ty.len() as i32;
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
@ -666,7 +706,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TList { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
RecordKey::Int(_) => { RecordKey::Int(_) => {
@ -681,6 +721,35 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => {
let ty1 = ty1[0];
let ty2 = ty2[0];
if id1 != id2 {
self.unify_impl(ty1, ty2, false)?;
}
self.set_a_to_b(a, b);
}
(TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => {
let ty1 = ty1[0];
self.unify_impl(ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => {
if val1 != val2 {
eprintln!("VALUE MISMATCH: lhs={val1:?} rhs={val2:?} eq={}", val1 == val2);
return self.incompatible_types(a, b)
}
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
@ -775,7 +844,14 @@ impl Unifier {
if id1 != id2 { if id1 != id2 {
self.incompatible_types(a, b)?; self.incompatible_types(a, b)?;
} }
for (x, y) in zip(params1.values(), params2.values()) {
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
// all K-V pairs "in arbitrary order"
let (tv1, tv2) = (
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
);
for (x, y) in zip(tv1, tv2) {
if self.unify_impl(*x, *y, false).is_err() { if self.unify_impl(*x, *y, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}; };
@ -928,6 +1004,9 @@ impl Unifier {
}; };
n n
} }
TypeEnum::TConstant { value, .. } => {
format!("const({value})")
}
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut fields = let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
@ -983,8 +1062,8 @@ impl Unifier {
} }
} }
/// Unifies `a` and `b` together, and set the value to the value of `b`.
fn set_a_to_b(&mut self, a: Type, b: Type) { fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let table = &mut self.unification_table; let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone(); let ty_b = table.probe_value(b).clone();
table.unify(a, b); table.unify(a, b);
@ -1207,6 +1286,7 @@ impl Unifier {
range, range,
name: name2.or(*name), name: name2.or(*name),
loc: loc2.or(*loc), loc: loc2.or(*loc),
is_const_generic: false,
}; };
Ok(Some(self.unification_table.new_key(ty.into()))) Ok(Some(self.unification_table.new_key(ty.into())))
} }

View File

@ -44,6 +44,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
def round_away_zero(x): def round_away_zero(x):
if x >= 0.0: if x >= 0.0:
return math.floor(x + 0.5) return math.floor(x + 0.5)
@ -99,6 +105,7 @@ def patch(module):
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric
module.Generic = Generic module.Generic = Generic
module.extern = extern module.extern = extern
module.Option = Option module.Option = Option

View File

@ -0,0 +1,50 @@
A = ConstGeneric("A", int32)
B = ConstGeneric("B", uint32)
T = TypeVar("T")
class ConstGenericClass(Generic[A]):
def __init__(self):
pass
class ConstGeneric2Class(Generic[A, B]):
def __init__(self):
pass
class HybridGenericClass2(Generic[A, T]):
pass
class HybridGenericClass3(Generic[T, A, B]):
pass
def make_generic_2() -> ConstGenericClass[2]:
return ...
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]:
return ...
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]:
return ...
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]:
return ...
def consume_generic_2(instance: ConstGenericClass[2]):
pass
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]):
pass
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]):
pass
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]):
pass
def f():
consume_generic_2(make_generic_2())
consume_generic2_1_2(make_generic2_1_2())
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
def run() -> int32:
return 0

View File

@ -25,7 +25,7 @@ use nac3core::{
}, },
}; };
use nac3parser::{ use nac3parser::{
ast::{Expr, ExprKind, StmtKind}, ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser, parser,
}; };
@ -83,6 +83,11 @@ fn handle_typevar_definition(
match &func.node { match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => { ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node))
};
let generic_name: StrRef = ty_name.to_string().into();
let constraints = args let constraints = args
.iter() .iter()
.skip(1) .skip(1)
@ -94,13 +99,50 @@ fn handle_typevar_definition(
primitives, primitives,
x, x,
Default::default(), Default::default(),
None,
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None def_list, unifier, primitives, &ty, &mut None
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) let loc = func.location;
if constraints.len() == 1 {
return Err(format!("A single constraint is not allowed (at {})", loc))
}
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
}
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 {
return Err(format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()))
}
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
))
};
let generic_name: StrRef = ty_name.to_string().into();
let ty = parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
&args[1],
Default::default(),
None,
)?;
let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None
)?;
let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
} }
_ => Err(format!( _ => Err(format!(