From 638d9f8a3072264049d946fc2b19d230dd0e2328 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 5 Dec 2023 14:37:08 +0800 Subject: [PATCH] core: Initial implementation for const generics --- nac3core/src/symbol_resolver.rs | 2 +- nac3core/src/toplevel/composer.rs | 8 +- nac3core/src/toplevel/type_annotation.rs | 247 +++++++++++++++-------- nac3core/src/typecheck/typedef/mod.rs | 98 ++++++++- nac3standalone/demo/interpret_demo.py | 9 +- nac3standalone/demo/src/const_generic.py | 50 +++++ nac3standalone/src/main.rs | 46 ++++- 7 files changed, 366 insertions(+), 94 deletions(-) create mode 100644 nac3standalone/demo/src/const_generic.py diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 8e61e5e..dff3271 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -85,7 +85,7 @@ impl SymbolValue { .map(|val| SymbolValue::U64(val)) .map_err(|e| e.to_string()) } else { - Err(format!("Expected {:?}, but got int", expected_ty)) + Err(format!("Expected {}, but got int", unifier.stringify(expected_ty))) } } Constant::Tuple(t) => { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 45ace1e..683d911 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -560,6 +560,7 @@ impl TopLevelComposer { &primitive_types, b, vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), + None, )?; if let TypeAnnotation::CustomClass { .. } = &base_ty { @@ -894,6 +895,7 @@ impl TopLevelComposer { // NOTE: since only class need this, for function // it should be fine to be empty map HashMap::new(), + None, )?; let type_vars_within = @@ -961,6 +963,7 @@ impl TopLevelComposer { // NOTE: since only class need this, for function // it should be fine to be empty map HashMap::new(), + None, )? }; @@ -1158,6 +1161,7 @@ impl TopLevelComposer { vec![(class_id, class_type_vars_def.clone())] .into_iter() .collect(), + None, )? }; // find type vars within this method parameter type annotation @@ -1221,6 +1225,7 @@ impl TopLevelComposer { primitives, result, vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), + None, )?; // find type vars within this return type annotation let type_vars_within = @@ -1317,6 +1322,7 @@ impl TopLevelComposer { primitives, annotation.as_ref(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), + None, )?; // find type vars within this return type annotation let type_vars_within = @@ -1735,7 +1741,7 @@ impl TopLevelComposer { .iter() .map(|(_, ty)| { unifier.get_instantiations(*ty).unwrap_or_else(|| { - if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) + if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) { let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; no_ranges.push(rigid); diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 9255090..0d7063f 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,3 +1,4 @@ +use crate::symbol_resolver::SymbolValue; use super::*; #[derive(Clone, Debug)] @@ -12,6 +13,16 @@ pub enum TypeAnnotation { // can only be CustomClassKind Virtual(Box), TypeVar(Type), + /// A constant used in the context of a const-generic variable. + Constant { + /// The non-type variable associated with this constant. + /// + /// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the + /// const generic variable of which this constant is associated with. + ty: Type, + /// The constant value of this constant. + value: SymbolValue + }, List(Box), Tuple(Vec), } @@ -47,6 +58,7 @@ impl TypeAnnotation { } ) } + Constant { value, .. } => format!("Const({value})"), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)), Tuple(types) => { @@ -56,6 +68,12 @@ impl TypeAnnotation { } } +/// Parses an AST expression `expr` into a [TypeAnnotation]. +/// +/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all +/// generic variables associated with the definition. +/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass +/// [None] when this function is invoked externally. pub fn parse_ast_to_type_annotation_kinds( resolver: &(dyn SymbolResolver + Send + Sync), top_level_defs: &[Arc>], @@ -64,6 +82,7 @@ pub fn parse_ast_to_type_annotation_kinds( expr: &ast::Expr, // the key stores the type_var of this topleveldef::class, we only need this field here locked: HashMap>, + type_var: Option, ) -> Result { let name_handle = |id: &StrRef, unifier: &mut Unifier, @@ -161,7 +180,8 @@ pub fn parse_ast_to_type_annotation_kinds( } let result = params_ast .iter() - .map(|x| { + .enumerate() + .map(|(idx, x)| { parse_ast_to_type_annotation_kinds( resolver, top_level_defs, @@ -172,6 +192,7 @@ pub fn parse_ast_to_type_annotation_kinds( locked.insert(obj_id, type_vars.clone()); locked.clone() }, + Some(type_vars[idx]), ) }) .collect::, _>>()?; @@ -190,6 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds( }; Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) }; + match &expr.node { ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), // virtual @@ -205,6 +227,7 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, + None, )?; if !matches!(def, TypeAnnotation::CustomClass { .. }) { unreachable!("must be concretized custom class kind in the virtual") @@ -225,6 +248,7 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, + None, )?; Ok(TypeAnnotation::List(def_ann.into())) } @@ -242,6 +266,7 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, slice.as_ref(), locked, + None, )?; let id = if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { @@ -275,6 +300,7 @@ pub fn parse_ast_to_type_annotation_kinds( primitives, e, locked.clone(), + None, ) }) .collect::, _>>()?; @@ -290,6 +316,31 @@ pub fn parse_ast_to_type_annotation_kinds( } } + ast::ExprKind::Constant { value, .. } => { + let type_var = type_var.expect("Expect type variable to be present"); + + let ntv_ty_enum = unifier.get_ty_immutable(type_var); + let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else { + unreachable!() + }; + let underlying_ty = underlying_ty[0]; + + let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)?; + + if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) { + return Err(format!( + "expression {} is not allowed for constant type annotation (at {})", + value.to_string(), + expr.location + )) + } + + Ok(TypeAnnotation::Constant { + ty: type_var, + value, + }) + } + _ => Err(format!("unsupported expression for type annotation (at {})", expr.location)), } } @@ -308,94 +359,130 @@ pub fn get_type_from_type_annotation_kinds( TypeAnnotation::CustomClass { id: obj_id, params } => { let def_read = top_level_defs[obj_id.0].read(); let class_def: &TopLevelDef = def_read.deref(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { - if type_vars.len() != params.len() { - Err(format!( - "unexpected number of type parameters: expected {} but got {}", - type_vars.len(), - params.len() - )) - } else { - let param_ty = params - .iter() - .map(|x| { - get_type_from_type_annotation_kinds( - top_level_defs, - unifier, - primitives, - x, - subst_list - ) - }) - .collect::, _>>()?; + let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else { + unreachable!("should be class def here") + }; - let subst = { - // check for compatible range - // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check - let mut result: HashMap = HashMap::new(); - for (tvar, p) in type_vars.iter().zip(param_ty) { - if let TypeEnum::TVar { id, range, fields: None, name, loc } = - unifier.get_ty(*tvar).as_ref() - { - let ok: bool = { - // create a temp type var and unify to check compatibility - p == *tvar || { - let temp = unifier.get_fresh_var_with_range( - range.as_slice(), - *name, - *loc, - ); - unifier.unify(temp.0, p).is_ok() - } - }; - if ok { - result.insert(*id, p); - } else { - return Err(format!( - "cannot apply type {} to type variable with id {:?}", - unifier.internal_stringify( - p, - &mut |id| format!("class{}", id), - &mut |id| format!("typevar{}", id), - &mut None - ), - *id - )); + if type_vars.len() != params.len() { + return Err(format!( + "unexpected number of type parameters: expected {} but got {}", + type_vars.len(), + params.len() + )) + } + + let param_ty = params + .iter() + .map(|x| { + get_type_from_type_annotation_kinds( + top_level_defs, + unifier, + primitives, + x, + subst_list + ) + }) + .collect::, _>>()?; + + let subst = { + // check for compatible range + // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check + let mut result: HashMap = HashMap::new(); + for (tvar, p) in type_vars.iter().zip(param_ty) { + match unifier.get_ty(*tvar).as_ref() { + TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => { + let ok: bool = { + // create a temp type var and unify to check compatibility + p == *tvar || { + let temp = unifier.get_fresh_var_with_range( + range.as_slice(), + *name, + *loc, + ); + unifier.unify(temp.0, p).is_ok() } + }; + if ok { + result.insert(*id, p); } else { - unreachable!("must be generic type var") + return Err(format!( + "cannot apply type {} to type variable with id {:?}", + unifier.internal_stringify( + p, + &mut |id| format!("class{}", id), + &mut |id| format!("typevar{}", id), + &mut None + ), + *id + )); } } - result - }; - let mut tobj_fields = methods - .iter() - .map(|(name, ty, _)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - // methods are immutable - (*name, (subst_ty, false)) - }) - .collect::>(); - tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); - (*name, (subst_ty, *mutability)) - })); - let need_subst = !subst.is_empty(); - let ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - fields: tobj_fields, - params: subst, - }); - if need_subst { - subst_list.as_mut().map(|wl| wl.push(ty)); + + TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => { + let ty = range[0]; + let ok: bool = { + // create a temp type var and unify to check compatibility + p == *tvar || { + let temp = unifier.get_fresh_const_generic_var( + ty, + *name, + *loc, + ); + unifier.unify(temp.0, p).is_ok() + } + }; + if ok { + result.insert(*id, p); + } else { + return Err(format!( + "cannot apply type {} to type variable {}", + unifier.stringify(p), + name.unwrap_or_else(|| format!("typevar{id}").into()), + )) + } + } + + _ => unreachable!("must be generic type var"), } - Ok(ty) } - } else { - unreachable!("should be class def here") + result + }; + let mut tobj_fields = methods + .iter() + .map(|(name, ty, _)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + // methods are immutable + (*name, (subst_ty, false)) + }) + .collect::>(); + tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { + let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + (*name, (subst_ty, *mutability)) + })); + let need_subst = !subst.is_empty(); + let ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields, + params: subst, + }); + if need_subst { + subst_list.as_mut().map(|wl| wl.push(ty)); } + Ok(ty) } TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), + TypeAnnotation::Constant { ty, value, .. } => { + let ty_enum = unifier.get_ty(*ty); + let (ty, loc) = match &*ty_enum { + TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => { + (ntv_underlying_ty[0], loc) + } + _ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()), + }; + + let var = unifier.get_fresh_constant(value.clone(), ty, *loc); + Ok(var) + } TypeAnnotation::Virtual(ty) => { let ty = get_type_from_type_annotation_kinds( top_level_defs, @@ -470,7 +557,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec {} + TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {} } result } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 393de14..264bb17 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -134,6 +134,17 @@ pub enum TypeEnum { range: Vec, name: Option, loc: Option, + /// Whether this type variable refers to a const-generic variable. + is_const_generic: bool, + }, + + /// A constant for substitution into a const generic variable. + TConstant { + /// The value of the constant. + value: SymbolValue, + /// The underlying type of the value. + ty: Type, + loc: Option, }, /// A tuple type. @@ -178,6 +189,7 @@ impl TypeEnum { match self { TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TVar { .. } => "TVar", + TypeEnum::TConstant { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", TypeEnum::TObj { .. } => "TObj", @@ -263,6 +275,7 @@ impl Unifier { fields: Some(fields), name: None, loc: None, + is_const_generic: false, }) } @@ -336,7 +349,33 @@ impl Unifier { let id = self.var_id + 1; self.var_id += 1; let range = range.to_vec(); - (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id) + (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) + } + + /// Returns a fresh type representing a constant generic variable with the given underlying type + /// `ty`. + pub fn get_fresh_const_generic_var( + &mut self, + ty: Type, + name: Option, + loc: Option, + ) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + (self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) + } + + /// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given + /// `value` and type `ty`. + pub fn get_fresh_constant( + &mut self, + value: SymbolValue, + ty: Type, + loc: Option, + ) -> Type { + assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. })); + + self.add_ty(TypeEnum::TConstant { ty, value, loc }) } /// Unification would not unify rigid variables with other types, but we want to do this for @@ -412,7 +451,7 @@ impl Unifier { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { - TRigidVar { .. } => true, + TRigidVar { .. } | TConstant { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } => self.is_concrete(*ty, allowed_typevars), @@ -560,8 +599,8 @@ impl Unifier { }; match (&*ty_a, &*ty_b) { ( - TVar { fields: fields1, id, name: name1, loc: loc1, .. }, - TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. }, + TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. }, + TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. }, ) => { let new_fields = match (fields1, fields2) { (None, None) => None, @@ -616,10 +655,11 @@ impl Unifier { range, name: name1.or(*name2), loc: loc1.or(*loc2), + is_const_generic: false, }), ); } - (TVar { fields: None, range, .. }, _) => { + (TVar { fields: None, range, is_const_generic: false, .. }, _) => { // We check for the range of the type variable to see if unification is allowed. // Note that although b may be compatible with a, we may have to constrain type // variables in b to make sure that instantiations of b would always be compatible @@ -636,7 +676,7 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, .. }, TTuple { ty }) => { + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => { let len = ty.len() as i32; for (k, v) in fields.iter() { match *k { @@ -666,7 +706,7 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, .. }, TList { ty }) => { + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => { for (k, v) in fields.iter() { match *k { RecordKey::Int(_) => { @@ -681,6 +721,35 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } + + (TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => { + let ty1 = ty1[0]; + let ty2 = ty2[0]; + + if id1 != id2 { + self.unify_impl(ty1, ty2, false)?; + } + + self.set_a_to_b(a, b); + } + + (TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => { + let ty1 = ty1[0]; + + self.unify_impl(ty1, *ty2, false)?; + self.set_a_to_b(a, b); + } + + (TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => { + if val1 != val2 { + eprintln!("VALUE MISMATCH: lhs={val1:?} rhs={val2:?} eq={}", val1 == val2); + return self.incompatible_types(a, b) + } + self.unify_impl(*ty1, *ty2, false)?; + + self.set_a_to_b(a, b); + } + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { if ty1.len() != ty2.len() { return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); @@ -775,7 +844,14 @@ impl Unifier { if id1 != id2 { self.incompatible_types(a, b)?; } - for (x, y) in zip(params1.values(), params2.values()) { + + // Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits + // all K-V pairs "in arbitrary order" + let (tv1, tv2) = ( + params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), + params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), + ); + for (x, y) in zip(tv1, tv2) { if self.unify_impl(*x, *y, false).is_err() { return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); }; @@ -928,6 +1004,9 @@ impl Unifier { }; n } + TypeEnum::TConstant { value, .. } => { + format!("const({value})") + } TypeEnum::TTuple { ty } => { let mut fields = ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); @@ -983,8 +1062,8 @@ impl Unifier { } } + /// Unifies `a` and `b` together, and set the value to the value of `b`. fn set_a_to_b(&mut self, a: Type, b: Type) { - // unify a and b together, and set the value to b's value. let table = &mut self.unification_table; let ty_b = table.probe_value(b).clone(); table.unify(a, b); @@ -1207,6 +1286,7 @@ impl Unifier { range, name: name2.or(*name), loc: loc2.or(*loc), + is_const_generic: false, }; Ok(Some(self.unification_table.new_key(ty.into()))) } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 9a753d0..b13f3ef 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -9,7 +9,7 @@ import pathlib from numpy import int32, int64, uint32, uint64 from scipy import special -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Any T = TypeVar('T') class Option(Generic[T]): @@ -44,6 +44,12 @@ def Some(v: T) -> Option[T]: none = Option(None) +class _ConstGenericDummy: + pass + +def ConstGeneric(name, constraint): + return TypeVar(name, _ConstGenericDummy, constraint) + def round_away_zero(x): if x >= 0.0: return math.floor(x + 0.5) @@ -99,6 +105,7 @@ def patch(module): module.uint32 = uint32 module.uint64 = uint64 module.TypeVar = TypeVar + module.ConstGeneric = ConstGeneric module.Generic = Generic module.extern = extern module.Option = Option diff --git a/nac3standalone/demo/src/const_generic.py b/nac3standalone/demo/src/const_generic.py new file mode 100644 index 0000000..2ddca99 --- /dev/null +++ b/nac3standalone/demo/src/const_generic.py @@ -0,0 +1,50 @@ +A = ConstGeneric("A", int32) +B = ConstGeneric("B", uint32) +T = TypeVar("T") + +class ConstGenericClass(Generic[A]): + def __init__(self): + pass + +class ConstGeneric2Class(Generic[A, B]): + def __init__(self): + pass + +class HybridGenericClass2(Generic[A, T]): + pass + +class HybridGenericClass3(Generic[T, A, B]): + pass + +def make_generic_2() -> ConstGenericClass[2]: + return ... + +def make_generic2_1_2() -> ConstGeneric2Class[1, 2]: + return ... + +def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]: + return ... + +def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]: + return ... + +def consume_generic_2(instance: ConstGenericClass[2]): + pass + +def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]): + pass + +def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]): + pass + +def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]): + pass + +def f(): + consume_generic_2(make_generic_2()) + consume_generic2_1_2(make_generic2_1_2()) + consume_hybrid_class_2_i32(make_hybrid_class_2_int32()) + consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1()) + +def run() -> int32: + return 0 \ No newline at end of file diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index a6dc749..3a91c47 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -25,7 +25,7 @@ use nac3core::{ }, }; use nac3parser::{ - ast::{Expr, ExprKind, StmtKind}, + ast::{Constant, Expr, ExprKind, StmtKind, StrRef}, parser, }; @@ -83,6 +83,11 @@ fn handle_typevar_definition( match &func.node { ExprKind::Name { id, .. } if id == &"TypeVar".into() => { + let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { + unreachable!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node) + }; + let generic_name: StrRef = ty_name.to_string().into(); + let constraints = args .iter() .skip(1) @@ -94,13 +99,50 @@ fn handle_typevar_definition( primitives, x, Default::default(), + None, )?; get_type_from_type_annotation_kinds( def_list, unifier, primitives, &ty, &mut None ) }) .collect::, _>>()?; - Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) + let loc = func.location; + + if constraints.len() == 1 { + return Err(format!("A single constraint is not allowed (at {})", loc)) + } + + Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) + } + + ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { + if args.len() != 2 { + return Err(format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len())) + } + + let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { + return Err(format!( + "Expected string constant for first parameter of `ConstGeneric`, got {:?}", + &args[0].node + )) + }; + let generic_name: StrRef = ty_name.to_string().into(); + + let ty = parse_ast_to_type_annotation_kinds( + resolver, + def_list, + unifier, + primitives, + &args[1], + Default::default(), + None, + )?; + let constraint = get_type_from_type_annotation_kinds( + def_list, unifier, primitives, &ty, &mut None + )?; + let loc = func.location; + + Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) } _ => Err(format!(