diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index fda8cea..50eaccc 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -84,31 +84,36 @@ impl TopLevelComposer { Err(format!("no method {} in the current class", method_name)) } - /// get all base class def id of a class, including itself - pub fn get_all_base( - child: DefinitionId, + /// get all base class def id of a class, excluding itself. \ + /// this function should called only after the direct parent is set + /// and before all the ancestors are set + /// and when we allow single inheritance \ + /// the order of the returned list is from the child to the deepest ancestor + pub fn get_all_ancestors_helper( + child: &TypeAnnotation, temp_def_list: &[Arc>], - ) -> Vec { - let mut result: Vec = Vec::new(); - let child_def = temp_def_list.get(child.0).unwrap(); - let child_def = child_def.read(); - let child_def = child_def.deref(); - - if let TopLevelDef::Class { ancestors, .. } = child_def { - for a in ancestors { - if let TypeAnnotation::CustomClassKind { id, .. } = a { - if *id != child { - result.extend(Self::get_all_base(*id, temp_def_list)); - } - } else { - unreachable!("must be class type annotation type") - } - } - } else { - unreachable!("this function should only be called with class def id as parameter") + ) -> Vec { + let mut result: Vec = Vec::new(); + let mut parent = Self::get_parent(child, temp_def_list); + while let Some(p) = parent { + parent = Self::get_parent(&p, temp_def_list); + result.push(p); } - - result.push(child); result } + + fn get_parent( + child: &TypeAnnotation, + temp_def_list: &[Arc>], + ) -> Option { + let child_id = + if let TypeAnnotation::CustomClassKind { id, .. } = child { *id } else { return None }; + let child_def = temp_def_list.get(child_id.0).unwrap(); + let child_def = child_def.read(); + if let TopLevelDef::Class { ancestors, .. } = &*child_def { + Some(ancestors[0].clone()) + } else { + None + } + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 4d67f55..e14772b 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -16,7 +16,7 @@ use itertools::{izip, Itertools}; use parking_lot::RwLock; use rustpython_parser::ast::{self, Stmt}; -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub struct DefinitionId(pub usize); mod type_annotation; @@ -393,6 +393,8 @@ impl TopLevelComposer { fn analyze_top_level_class_bases(&mut self) -> Result<(), String> { let temp_def_list = self.extract_def_list(); let unifier = self.unifier.borrow_mut(); + + // first, only push direct parent into the list for (class_def, class_ast) in self.definition_ast_list.iter_mut() { let mut class_def = class_def.write(); let (class_bases, class_ancestors, class_resolver, class_id, class_type_vars) = { @@ -414,7 +416,6 @@ impl TopLevelComposer { let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.deref(); - // only allow single inheritance let mut has_base = false; for b in class_bases { // type vars have already been handled, so skip on `Generic[...]` @@ -436,6 +437,8 @@ impl TopLevelComposer { } has_base = true; + // the function parse_ast_to make sure that no type var occured in + // bast_ty if it is a CustomClassKind let base_ty = parse_ast_to_type_annotation_kinds( class_resolver, &temp_def_list, @@ -444,36 +447,8 @@ impl TopLevelComposer { b, )?; - if let TypeAnnotation::CustomClassKind { id, .. } = &base_ty { - // TODO: change a way to check to prevent cyclic base class - let all_base = Self::get_all_base(*id, &temp_def_list); - if all_base.contains(&class_id) { - return Err("cyclic base detected".into()); - } - - // NOTE: type vars occured in the base type annotation must be - // a subset of the class decalred type annotation - let base_type_var_ids = get_type_var_contained_in_type_annotation(&base_ty) - .into_iter() - .map(|x| { - if let TypeAnnotation::TypeVarKind(ty) = x { - get_var_id(ty, unifier) - } else { - unreachable!("must be type var") - } - }) - .collect::, _>>()?; - let class_generic_type_var_ids = class_type_vars - .iter() - .map(|x| get_var_id(*x, unifier)) - .collect::, _>>()?; - if class_generic_type_var_ids.is_superset(&base_type_var_ids) { - // TODO: this base confirmed - // in what order to push all the ancestors? - // class_ancestors.push(base_ty); - } else { - return Err("base class generic type parameter must be declared".into()); - } + if let TypeAnnotation::CustomClassKind { .. } = &base_ty { + class_ancestors.push(base_ty); } else { return Err("class base declaration can only be custom class".into()); } @@ -481,8 +456,57 @@ impl TopLevelComposer { // TODO: ancestors should include all bases, need to rewrite // push self to the ancestors - class_ancestors.push(make_self_type_annotation(&temp_def_list, class_id)?) + // class_ancestors.push(make_self_type_annotation(&temp_def_list, class_id)?) } + + // second get all ancestors + let mut ancestors_store: HashMap> = Default::default(); + for (class_def, class_ast) in self.definition_ast_list.iter_mut() { + let mut class_def = class_def.write(); + let (class_ancestors, class_id) = { + if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref_mut() { + if let Some(ast::Located { node: ast::StmtKind::ClassDef { .. }, .. }) = + class_ast + { + (ancestors, *object_id) + } else { + unreachable!("must be both class") + } + } else { + continue; + } + }; + ancestors_store.insert( + class_id, + Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice()), + ); + } + + // insert the ancestors to the def list + for (class_def, class_ast) in self.definition_ast_list.iter_mut() { + let mut class_def = class_def.write(); + let (class_ancestors, class_id) = { + if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref_mut() { + if let Some(ast::Located { node: ast::StmtKind::ClassDef { .. }, .. }) = + class_ast + { + (ancestors, *object_id) + } else { + unreachable!("must be both class") + } + } else { + continue; + } + }; + + let ans = ancestors_store.get_mut(&class_id).unwrap(); + class_ancestors.append(ans); + + // insert self type annotation + class_ancestors + .insert(0, make_self_type_annotation(temp_def_list.as_slice(), class_id)?); + } + Ok(()) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index e8306ae..fc1759c 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -95,7 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds( params_ast.len() )); } - params_ast + let result = params_ast .into_iter() .map(|x| { parse_ast_to_type_annotation_kinds( @@ -106,7 +106,19 @@ pub fn parse_ast_to_type_annotation_kinds( x, ) }) - .collect::, _>>()? + .collect::, _>>()?; + + // make sure the result do not contain any type vars + let no_type_var = result + .iter() + .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); + if no_type_var { + result + } else { + return Err("application of type vars to generic class \ + not currently supported" + .into()); + } }; // allow type var in class generic application list @@ -156,9 +168,6 @@ pub fn get_type_from_type_annotation_kinds( .iter() .map(|x| { if let TypeEnum::TVar { id, .. } = unifier.get_ty(*x).as_ref() { - // this is for the class generic application, - // we only need the information for the copied type var - // associated with the class *id } else { unreachable!()