/// obtain class and function signature from AST use super::primitives::PRIMITIVES; use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup}; fn typename_from_expr<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, expr: &'b ExpressionType) { match expr { ExpressionType::Identifier { name } => typenames.push(&name), ExpressionType::String { value } => match value { StringGroup::Constant { value } => typenames.push(&value), _ => unimplemented!(), }, ExpressionType::Subscript { a, b } => { typename_from_expr(typenames, &b.node); typename_from_expr(typenames, &a.node) } _ => unimplemented!(), } } fn typename_from_fn<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, fun: &'b StatementType) { match fun { StatementType::FunctionDef { args, returns, .. } => { for arg in args.args.iter() { if let Some(ann) = &arg.annotation { typename_from_expr(typenames, &ann.node); } } if let Some(returns) = &returns { typename_from_expr(typenames, &returns.node); } } _ => unreachable!(), } } pub fn resolve_classes<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { let mut classes = Vec::new(); let mut typenames = Vec::new(); for stmt in stmts.iter() { match &stmt.node { StatementType::ClassDef { name, body, bases, .. } => { // check if class is not duplicated... // and annotations classes.push(&name[..]); for base in bases.iter() { let name = match &base.node { ExpressionType::Identifier { name } => name, ExpressionType::String { value } => match value { StringGroup::Constant { value } => value, _ => unimplemented!(), }, _ => unimplemented!(), }; typenames.push(&name[..]); } // may check if fields/functions are not duplicated for stmt in body.iter() { match &stmt.node { StatementType::AnnAssign { annotation, .. } => { typename_from_expr(&mut typenames, &annotation.node) } StatementType::FunctionDef { .. } => { typename_from_fn(&mut typenames, &stmt.node); } _ => unimplemented!(), } } } StatementType::FunctionDef { .. } => { // may check annotations typename_from_fn(&mut typenames, &stmt.node); } _ => (), } } let mut unknowns = Vec::new(); for n in typenames { if !PRIMITIVES.contains(&n) && !classes.contains(&n) && !unknowns.contains(&n) { unknowns.push(n); } } (classes, unknowns) } #[cfg(test)] mod tests { use super::*; use indoc::indoc; use rustpython_parser::parser::parse_program; #[test] fn test_get_classes() { let ast = parse_program(indoc! {" class Foo: a: int32 b: Test def test(self, a: int32) -> Test2: return b class Bar(Foo, 'FooBar'): def test2(self, a: list[Foo]) -> Test2: return b def test3(self, a: list[FooBar2]) -> Test2: return b " }) .unwrap(); let (mut classes, mut unknowns) = resolve_classes(&ast.statements); let classes_count = classes.len(); let unknowns_count = unknowns.len(); classes.sort(); unknowns.sort(); assert_eq!(classes.len(), classes_count); assert_eq!(unknowns.len(), unknowns_count); assert_eq!(&classes, &["Bar", "Foo"]); assert_eq!(&unknowns, &["FooBar", "FooBar2", "Test", "Test2"]); } }