diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 72776b4d..bb4ea3db 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -349,6 +349,8 @@ impl TopLevelComposer { } /// step 1, analyze the type vars associated with top level class + /// note that we make a duplicate of the type var returned by symbol resolver + /// since one top level type var may be used at multiple places fn analyze_top_level_class_type_var(&mut self) -> Result<(), String> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -441,7 +443,7 @@ impl TopLevelComposer { .map(|x| { // must be type var here after previous check let dup = duplicate_type_var(unifier, x); - (dup.1, dup.0) + (dup.2, dup.0) }) .collect_vec(); @@ -458,6 +460,8 @@ impl TopLevelComposer { } /// step 2, base classes. + /// now that the type vars of all classes are done, handle base classes and + /// put Self class into the ancestors list. We only allow single inheritance fn analyze_top_level_class_bases(&mut self) -> Result<(), String> { let temp_def_list = self.extract_def_list(); for (class_def, class_ast) in self.definition_ast_list.iter_mut() { @@ -479,6 +483,7 @@ impl TopLevelComposer { let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.deref(); + // only allow single inheritance let mut has_base = false; for b in class_bases { // type vars have already been handled, so skip on `Generic[...]` @@ -508,8 +513,12 @@ impl TopLevelComposer { b, )?; - if let TypeAnnotation::CustomClassKind { .. } = &base_ty { - // TODO: check to prevent cyclic base class + if let TypeAnnotation::CustomClassKind { id, .. } = &base_ty { + // check to prevent cyclic base class + let all_base = Self::get_all_base(*id, &temp_def_list); + if all_base.contains(&class_id) { + return Err("cyclic base detected".into()); + } class_ancestors.push(base_ty); } else { return Err( @@ -905,4 +914,32 @@ impl TopLevelComposer { } Err(format!("no method {} in the current class", method_name)) } + + /// get all base class def id of a class, including it self + fn get_all_base( + child: DefinitionId, + temp_def_list: &[Arc>] + ) -> Vec { + let mut result: Vec = Vec::new(); + let child_def = temp_def_list.get(child.0).unwrap(); + let child_def = child_def.read(); + let child_def = child_def.deref(); + + if let TopLevelDef::Class { ancestors, .. } = child_def { + for a in ancestors { + if let TypeAnnotation::CustomClassKind { id, .. } = a { + if *id != child { + result.extend(Self::get_all_base(*id, temp_def_list)); + } + } else { + unreachable!("must be class type annotation type") + } + } + } else { + unreachable!("this function should only be called with class def id as parameter") + } + + result.push(child); + result + } }