core: Initial implementation for const generics

This commit is contained in:
David Mak 2023-12-05 14:37:08 +08:00
parent b9b725e78e
commit 638d9f8a30
7 changed files with 366 additions and 94 deletions

View File

@ -85,7 +85,7 @@ impl SymbolValue {
.map(|val| SymbolValue::U64(val)) .map(|val| SymbolValue::U64(val))
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
} else { } else {
Err(format!("Expected {:?}, but got int", expected_ty)) Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
} }
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {

View File

@ -560,6 +560,7 @@ impl TopLevelComposer {
&primitive_types, &primitive_types,
b, b,
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(),
None,
)?; )?;
if let TypeAnnotation::CustomClass { .. } = &base_ty { if let TypeAnnotation::CustomClass { .. } = &base_ty {
@ -894,6 +895,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)?; )?;
let type_vars_within = let type_vars_within =
@ -961,6 +963,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)? )?
}; };
@ -1158,6 +1161,7 @@ impl TopLevelComposer {
vec![(class_id, class_type_vars_def.clone())] vec![(class_id, class_type_vars_def.clone())]
.into_iter() .into_iter()
.collect(), .collect(),
None,
)? )?
}; };
// find type vars within this method parameter type annotation // find type vars within this method parameter type annotation
@ -1221,6 +1225,7 @@ impl TopLevelComposer {
primitives, primitives,
result, result,
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1317,6 +1322,7 @@ impl TopLevelComposer {
primitives, primitives,
annotation.as_ref(), annotation.as_ref(),
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1735,7 +1741,7 @@ impl TopLevelComposer {
.iter() .iter()
.map(|(_, ty)| { .map(|(_, ty)| {
unifier.get_instantiations(*ty).unwrap_or_else(|| { unifier.get_instantiations(*ty).unwrap_or_else(|| {
if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty)
{ {
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
no_ranges.push(rigid); no_ranges.push(rigid);

View File

@ -1,3 +1,4 @@
use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -12,6 +13,16 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
/// A constant used in the context of a const-generic variable.
Constant {
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>), List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -47,6 +58,7 @@ impl TypeAnnotation {
} }
) )
} }
Constant { value, .. } => format!("Const({value})"),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
@ -56,6 +68,12 @@ impl TypeAnnotation {
} }
} }
/// Parses an AST expression `expr` into a [TypeAnnotation].
///
/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [None] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -64,6 +82,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
type_var: Option<Type>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
@ -161,7 +180,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
let result = params_ast let result = params_ast
.iter() .iter()
.map(|x| { .enumerate()
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds( parse_ast_to_type_annotation_kinds(
resolver, resolver,
top_level_defs, top_level_defs,
@ -172,6 +192,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
locked.insert(obj_id, type_vars.clone()); locked.insert(obj_id, type_vars.clone());
locked.clone() locked.clone()
}, },
Some(type_vars[idx]),
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -190,6 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
}; };
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual // virtual
@ -205,6 +227,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) { if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual") unreachable!("must be concretized custom class kind in the virtual")
@ -225,6 +248,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
Ok(TypeAnnotation::List(def_ann.into())) Ok(TypeAnnotation::List(def_ann.into()))
} }
@ -242,6 +266,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
let id = let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
@ -275,6 +300,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
e, e,
locked.clone(), locked.clone(),
None,
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -290,6 +316,31 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
} }
ast::ExprKind::Constant { value, .. } => {
let type_var = type_var.expect("Expect type variable to be present");
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(format!(
"expression {} is not allowed for constant type annotation (at {})",
value.to_string(),
expr.location
))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
}
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)), _ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
} }
} }
@ -308,94 +359,130 @@ pub fn get_type_from_type_annotation_kinds(
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read(); let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = def_read.deref(); let class_def: &TopLevelDef = def_read.deref();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
if type_vars.len() != params.len() { unreachable!("should be class def here")
Err(format!( };
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
))
} else {
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let subst = { if type_vars.len() != params.len() {
// check for compatible range return Err(format!(
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check "unexpected number of type parameters: expected {} but got {}",
let mut result: HashMap<u32, Type> = HashMap::new(); type_vars.len(),
for (tvar, p) in type_vars.iter().zip(param_ty) { params.len()
if let TypeEnum::TVar { id, range, fields: None, name, loc } = ))
unifier.get_ty(*tvar).as_ref() }
{
let ok: bool = { let param_ty = params
// create a temp type var and unify to check compatibility .iter()
p == *tvar || { .map(|x| {
let temp = unifier.get_fresh_var_with_range( get_type_from_type_annotation_kinds(
range.as_slice(), top_level_defs,
*name, unifier,
*loc, primitives,
); x,
unifier.unify(temp.0, p).is_ok() subst_list
} )
}; })
if ok { .collect::<Result<Vec<_>, _>>()?;
result.insert(*id, p);
} else { let subst = {
return Err(format!( // check for compatible range
"cannot apply type {} to type variable with id {:?}", // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
unifier.internal_stringify( let mut result: HashMap<u32, Type> = HashMap::new();
p, for (tvar, p) in type_vars.iter().zip(param_ty) {
&mut |id| format!("class{}", id), match unifier.get_ty(*tvar).as_ref() {
&mut |id| format!("typevar{}", id), TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
&mut None let ok: bool = {
), // create a temp type var and unify to check compatibility
*id p == *tvar || {
)); let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
} }
};
if ok {
result.insert(*id, p);
} else { } else {
unreachable!("must be generic type var") return Err(format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{}", id),
&mut |id| format!("typevar{}", id),
&mut None
),
*id
));
} }
} }
result
}; TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let mut tobj_fields = methods let ty = range[0];
.iter() let ok: bool = {
.map(|(name, ty, _)| { // create a temp type var and unify to check compatibility
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); p == *tvar || {
// methods are immutable let temp = unifier.get_fresh_const_generic_var(
(*name, (subst_ty, false)) ty,
}) *name,
.collect::<HashMap<_, _>>(); *loc,
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { );
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); unifier.unify(temp.0, p).is_ok()
(*name, (subst_ty, *mutability)) }
})); };
let need_subst = !subst.is_empty(); if ok {
let ty = unifier.add_ty(TypeEnum::TObj { result.insert(*id, p);
obj_id: *obj_id, } else {
fields: tobj_fields, return Err(format!(
params: subst, "cannot apply type {} to type variable {}",
}); unifier.stringify(p),
if need_subst { name.unwrap_or_else(|| format!("typevar{id}").into()),
subst_list.as_mut().map(|wl| wl.push(ty)); ))
}
}
_ => unreachable!("must be generic type var"),
} }
Ok(ty)
} }
} else { result
unreachable!("should be class def here") };
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
subst_list.as_mut().map(|wl| wl.push(ty));
} }
Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => {
let ty_enum = unifier.get_ty(*ty);
let (ty, loc) = match &*ty_enum {
TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => {
(ntv_underlying_ty[0], loc)
}
_ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()),
};
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var)
}
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
@ -470,7 +557,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {}
} }
result result
} }

View File

@ -134,6 +134,17 @@ pub enum TypeEnum {
range: Vec<Type>, range: Vec<Type>,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
/// Whether this type variable refers to a const-generic variable.
is_const_generic: bool,
},
/// A constant for substitution into a const generic variable.
TConstant {
/// The value of the constant.
value: SymbolValue,
/// The underlying type of the value.
ty: Type,
loc: Option<Location>,
}, },
/// A tuple type. /// A tuple type.
@ -178,6 +189,7 @@ impl TypeEnum {
match self { match self {
TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar", TypeEnum::TVar { .. } => "TVar",
TypeEnum::TConstant { .. } => "TConstant",
TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList", TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj", TypeEnum::TObj { .. } => "TObj",
@ -263,6 +275,7 @@ impl Unifier {
fields: Some(fields), fields: Some(fields),
name: None, name: None,
loc: None, loc: None,
is_const_generic: false,
}) })
} }
@ -336,7 +349,33 @@ impl Unifier {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id) (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
}
/// Returns a fresh type representing a constant generic variable with the given underlying type
/// `ty`.
pub fn get_fresh_const_generic_var(
&mut self,
ty: Type,
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
}
/// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given
/// `value` and type `ty`.
pub fn get_fresh_constant(
&mut self,
value: SymbolValue,
ty: Type,
loc: Option<Location>,
) -> Type {
assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. }));
self.add_ty(TypeEnum::TConstant { ty, value, loc })
} }
/// Unification would not unify rigid variables with other types, but we want to do this for /// Unification would not unify rigid variables with other types, but we want to do this for
@ -412,7 +451,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } => true, TRigidVar { .. } | TConstant { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } => self.is_concrete(*ty, allowed_typevars),
@ -560,8 +599,8 @@ impl Unifier {
}; };
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
( (
TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. },
TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. }, TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. },
) => { ) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
@ -616,10 +655,11 @@ impl Unifier {
range, range,
name: name1.or(*name2), name: name1.or(*name2),
loc: loc1.or(*loc2), loc: loc1.or(*loc2),
is_const_generic: false,
}), }),
); );
} }
(TVar { fields: None, range, .. }, _) => { (TVar { fields: None, range, is_const_generic: false, .. }, _) => {
// We check for the range of the type variable to see if unification is allowed. // We check for the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type // Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible // variables in b to make sure that instantiations of b would always be compatible
@ -636,7 +676,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TTuple { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = ty.len() as i32; let len = ty.len() as i32;
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
@ -666,7 +706,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TList { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
RecordKey::Int(_) => { RecordKey::Int(_) => {
@ -681,6 +721,35 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => {
let ty1 = ty1[0];
let ty2 = ty2[0];
if id1 != id2 {
self.unify_impl(ty1, ty2, false)?;
}
self.set_a_to_b(a, b);
}
(TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => {
let ty1 = ty1[0];
self.unify_impl(ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => {
if val1 != val2 {
eprintln!("VALUE MISMATCH: lhs={val1:?} rhs={val2:?} eq={}", val1 == val2);
return self.incompatible_types(a, b)
}
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
@ -775,7 +844,14 @@ impl Unifier {
if id1 != id2 { if id1 != id2 {
self.incompatible_types(a, b)?; self.incompatible_types(a, b)?;
} }
for (x, y) in zip(params1.values(), params2.values()) {
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
// all K-V pairs "in arbitrary order"
let (tv1, tv2) = (
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
);
for (x, y) in zip(tv1, tv2) {
if self.unify_impl(*x, *y, false).is_err() { if self.unify_impl(*x, *y, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}; };
@ -928,6 +1004,9 @@ impl Unifier {
}; };
n n
} }
TypeEnum::TConstant { value, .. } => {
format!("const({value})")
}
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut fields = let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
@ -983,8 +1062,8 @@ impl Unifier {
} }
} }
/// Unifies `a` and `b` together, and set the value to the value of `b`.
fn set_a_to_b(&mut self, a: Type, b: Type) { fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let table = &mut self.unification_table; let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone(); let ty_b = table.probe_value(b).clone();
table.unify(a, b); table.unify(a, b);
@ -1207,6 +1286,7 @@ impl Unifier {
range, range,
name: name2.or(*name), name: name2.or(*name),
loc: loc2.or(*loc), loc: loc2.or(*loc),
is_const_generic: false,
}; };
Ok(Some(self.unification_table.new_key(ty.into()))) Ok(Some(self.unification_table.new_key(ty.into())))
} }

View File

@ -9,7 +9,7 @@ import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from scipy import special from scipy import special
from typing import TypeVar, Generic from typing import TypeVar, Generic, Any
T = TypeVar('T') T = TypeVar('T')
class Option(Generic[T]): class Option(Generic[T]):
@ -44,6 +44,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
class _ConstGenericDummy:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericDummy, constraint)
def round_away_zero(x): def round_away_zero(x):
if x >= 0.0: if x >= 0.0:
return math.floor(x + 0.5) return math.floor(x + 0.5)
@ -99,6 +105,7 @@ def patch(module):
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric
module.Generic = Generic module.Generic = Generic
module.extern = extern module.extern = extern
module.Option = Option module.Option = Option

View File

@ -0,0 +1,50 @@
A = ConstGeneric("A", int32)
B = ConstGeneric("B", uint32)
T = TypeVar("T")
class ConstGenericClass(Generic[A]):
def __init__(self):
pass
class ConstGeneric2Class(Generic[A, B]):
def __init__(self):
pass
class HybridGenericClass2(Generic[A, T]):
pass
class HybridGenericClass3(Generic[T, A, B]):
pass
def make_generic_2() -> ConstGenericClass[2]:
return ...
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]:
return ...
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]:
return ...
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]:
return ...
def consume_generic_2(instance: ConstGenericClass[2]):
pass
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]):
pass
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]):
pass
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]):
pass
def f():
consume_generic_2(make_generic_2())
consume_generic2_1_2(make_generic2_1_2())
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
def run() -> int32:
return 0

View File

@ -25,7 +25,7 @@ use nac3core::{
}, },
}; };
use nac3parser::{ use nac3parser::{
ast::{Expr, ExprKind, StmtKind}, ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser, parser,
}; };
@ -83,6 +83,11 @@ fn handle_typevar_definition(
match &func.node { match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => { ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
unreachable!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node)
};
let generic_name: StrRef = ty_name.to_string().into();
let constraints = args let constraints = args
.iter() .iter()
.skip(1) .skip(1)
@ -94,13 +99,50 @@ fn handle_typevar_definition(
primitives, primitives,
x, x,
Default::default(), Default::default(),
None,
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None def_list, unifier, primitives, &ty, &mut None
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) let loc = func.location;
if constraints.len() == 1 {
return Err(format!("A single constraint is not allowed (at {})", loc))
}
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
}
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 {
return Err(format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()))
}
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
))
};
let generic_name: StrRef = ty_name.to_string().into();
let ty = parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
&args[1],
Default::default(),
None,
)?;
let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None
)?;
let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
} }
_ => Err(format!( _ => Err(format!(