From 935e7410fd53c6ca75d7b22372046253fca84194 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Thu, 26 Aug 2021 11:54:37 +0800 Subject: [PATCH] check type params in class generic base declaration --- nac3core/src/toplevel/mod.rs | 25 ++++++++++++++++++++---- nac3core/src/toplevel/type_annotation.rs | 17 +++++++++++++--- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index bb4ea3dbb..4d02635b6 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -466,13 +466,13 @@ impl TopLevelComposer { let temp_def_list = self.extract_def_list(); for (class_def, class_ast) in self.definition_ast_list.iter_mut() { let mut class_def = class_def.write(); - let (class_bases, class_ancestors, class_resolver, class_id) = { - if let TopLevelDef::Class { ancestors, resolver, object_id, .. } = class_def.deref_mut() { + let (class_bases, class_ancestors, class_resolver, class_id, class_type_vars) = { + if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = class_def.deref_mut() { if let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast { - (bases, ancestors, resolver, *object_id) + (bases, ancestors, resolver, *object_id, type_vars) } else { unreachable!("must be both class") } @@ -519,10 +519,27 @@ impl TopLevelComposer { if all_base.contains(&class_id) { return Err("cyclic base detected".into()); } + + // find the intersection between type vars occured in the base class type parameter + // and the type vars occured in the class generic declaration + let type_var_occured_in_base = get_type_var_contained_in_type_annotation(&base_ty); + for type_ann in type_var_occured_in_base { + if let TypeAnnotation::TypeVarKind(id, ty) = type_ann { + for (ty_id, class_typvar_ty) in class_type_vars.iter() { + if id == *ty_id { + // if they refer to the same top level defined type var, we unify them together + self.unifier.unify(ty, *class_typvar_ty)?; + } + } + } else { + unreachable!("must be type var annotation") + } + } + class_ancestors.push(base_ty); } else { return Err( - "class base declaration can only be concretized custom class".into() + "class base declaration can only be custom class".into() ); } } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 7839437f8..429cc470b 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -38,7 +38,11 @@ pub fn parse_ast_to_type_annotation_kinds( x => { if let Some(obj_id) = resolver.get_identifier_def(x) { let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { .. } = &*def { + if let TopLevelDef::Class { type_vars, .. } = &*def { + // also check param number here + if !type_vars.is_empty() { + return Err(format!("expect {} type variable parameter but got 0", type_vars.len())) + } Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![], @@ -67,7 +71,7 @@ pub fn parse_ast_to_type_annotation_kinds( } }, - // TODO: subscript or call? + // TODO: subscript or call for virtual? ast::ExprKind::Subscript { value, slice, .. } if { matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "virtual") } => { @@ -93,7 +97,7 @@ pub fn parse_ast_to_type_annotation_kinds( .get_identifier_def(id) .ok_or_else(|| "unknown class name".to_string())?; let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { .. } = &*def { + if let TopLevelDef::Class { type_vars, .. } = &*def { let param_type_infos = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { elts.iter() .map(|v| { @@ -115,6 +119,13 @@ pub fn parse_ast_to_type_annotation_kinds( slice, )?] }; + if type_vars.len() != param_type_infos.len() { + return Err(format!( + "expect {} type parameters but got {}", + type_vars.len(), + param_type_infos.len() + )) + } // NOTE: allow type var in class generic application list Ok(TypeAnnotation::CustomClassKind { id: obj_id,