diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 7e3b132..9dac931 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -486,9 +486,9 @@ impl TopLevelComposer { // insert the ancestors to the def list for (class_def, _) in self.definition_ast_list.iter_mut() { let mut class_def = class_def.write(); - let (class_ancestors, class_id) = { - if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref_mut() { - (ancestors, *object_id) + let (class_ancestors, class_id, class_type_vars) = { + if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = class_def.deref_mut() { + (ancestors, *object_id, type_vars) } else { continue; } @@ -499,7 +499,7 @@ impl TopLevelComposer { // insert self type annotation to the front of the vector to maintain the order class_ancestors - .insert(0, make_self_type_annotation(temp_def_list.as_slice(), class_id)?); + .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); } Ok(()) @@ -862,7 +862,7 @@ impl TopLevelComposer { }; type_var_to_concrete_def.insert( dummy_func_arg.ty, - make_self_type_annotation(temp_def_list, class_id)?, + make_self_type_annotation(class_type_vars_def.as_slice(), class_id), ); result.push(dummy_func_arg); } @@ -916,7 +916,7 @@ impl TopLevelComposer { let dummy_return_type = unifier.get_fresh_var().0; type_var_to_concrete_def.insert( dummy_return_type, - make_self_type_annotation(temp_def_list, class_id)?, + make_self_type_annotation(class_type_vars_def.as_slice(), class_id), ); dummy_return_type } @@ -1035,8 +1035,9 @@ impl TopLevelComposer { { if class_method_name == anc_method_name { // ignore and handle self + // if is __init__ method, no need to check return type let ok = class_method_name == "__init__" - && Self::check_overload_function_type( + || Self::check_overload_function_type( *class_method_ty, *anc_method_ty, unifier, diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index d4b3538..f4c82eb 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -141,3 +141,55 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s } } } + +#[test_case( + vec![ + indoc! {" + class A: + def __init__(): + pass + "}, + indoc! {" + class B(C): + def __init__(): + pass + "}, + indoc! {" + class C(A): + def __init__(): + pass + "}, + indoc! {" + def foo(a: A): + pass + "}, + ] +)] +fn test_simple_class_analyze(source: Vec<&str>) { + let mut composer = TopLevelComposer::new(); + + let resolver = Arc::new(Mutex::new(Box::new(Resolver { + id_to_def: Default::default(), + id_to_type: Default::default(), + class_names: Default::default(), + }) as Box)); + + for s in source { + let ast = parse_program(s).unwrap(); + let ast = ast[0].clone(); + + let (id, def_id) = composer.register_top_level(ast, Some(resolver.clone())).unwrap(); + resolver.lock().add_id_def(id, def_id); + } + + composer.start_analysis().unwrap(); + + // for (i, (def, _)) in composer.definition_ast_list.into_iter().enumerate() { + // let def = &*def.read(); + // if let TopLevelDef::Function { signature, name, .. } = def { + // let ty_str = composer.unifier.stringify(*signature, &mut |id| id.to_string(), &mut |id| id.to_string()); + // assert_eq!(ty_str, tys[i]); + // assert_eq!(name, names[i]); + // } + // } +} \ No newline at end of file diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index eba7873..3834505 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -296,22 +296,10 @@ pub fn get_type_from_type_annotation_kinds( /// but equivalent to seeing `A[T, V]` inside the class def body ast, where although we /// create copies of `T` and `V`, we will find them out as occured type vars in the analyze_class() /// and unify them with the class generic `T`, `V` -pub fn make_self_type_annotation( - top_level_defs: &[Arc>], - def_id: DefinitionId, -) -> Result { - let obj_def = - top_level_defs.get(def_id.0).ok_or_else(|| "invalid definition id".to_string())?; - let obj_def = obj_def.read(); - let obj_def = obj_def.deref(); - - if let TopLevelDef::Class { type_vars, .. } = obj_def { - Ok(TypeAnnotation::CustomClassKind { - id: def_id, - params: type_vars.iter().map(|ty| TypeAnnotation::TypeVarKind(*ty)).collect_vec(), - }) - } else { - unreachable!("must be top level class def here") +pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { + TypeAnnotation::CustomClassKind { + id: object_id, + params: type_vars.iter().map(|ty| TypeAnnotation::TypeVarKind(*ty)).collect_vec(), } }