/// obtain class and function signature from AST use super::context::TopLevelContext; use super::primitives::*; use super::typedef::*; use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup}; use std::collections::HashMap; fn typename_from_expr<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, expr: &'b ExpressionType) { match expr { ExpressionType::Identifier { name } => typenames.push(&name), ExpressionType::String { value } => match value { StringGroup::Constant { value } => typenames.push(&value), _ => unimplemented!(), }, ExpressionType::Subscript { a, b } => { typename_from_expr(typenames, &b.node); typename_from_expr(typenames, &a.node) } _ => unimplemented!(), } } fn typename_from_fn<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, fun: &'b StatementType) { match fun { StatementType::FunctionDef { args, returns, .. } => { for arg in args.args.iter() { if let Some(ann) = &arg.annotation { typename_from_expr(typenames, &ann.node); } } if let Some(returns) = &returns { typename_from_expr(typenames, &returns.node); } } _ => unreachable!(), } } fn name_from_expr<'b: 'a, 'a>(expr: &'b ExpressionType) -> &'a str { match &expr { ExpressionType::Identifier { name } => &name, ExpressionType::String { value } => match value { StringGroup::Constant { value } => &value, _ => unimplemented!(), }, _ => unimplemented!(), } } fn type_from_expr<'b: 'a, 'a>( ctx: &'a TopLevelContext, expr: &'b ExpressionType, ) -> Result { match expr { ExpressionType::Identifier { name } => { ctx.get_type(name).ok_or_else(|| "no such type".into()) } ExpressionType::String { value } => match value { StringGroup::Constant { value } => { ctx.get_type(&value).ok_or_else(|| "no such type".into()) } _ => unimplemented!(), }, ExpressionType::Subscript { a, b } => { if let ExpressionType::Identifier { name } = &a.node { match name.as_str() { "list" => { let ty = type_from_expr(ctx, &b.node)?; Ok(TypeEnum::ParametricType(LIST_TYPE, vec![ty]).into()) } "tuple" => { if let ExpressionType::Tuple { elements } = &b.node { let ty_list: Result, _> = elements .iter() .map(|v| type_from_expr(ctx, &v.node)) .collect(); Ok(TypeEnum::ParametricType(TUPLE_TYPE, ty_list?).into()) } else { Err("unsupported format".into()) } } _ => Err("no such parameterized type".into()), } } else { // we require a to be an identifier, for a[b] Err("unsupported format".into()) } } _ => Err("unsupported format".into()), } } pub fn get_typenames<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { let mut classes = Vec::new(); let mut typenames = Vec::new(); for stmt in stmts.iter() { match &stmt.node { StatementType::ClassDef { name, body, bases, .. } => { // check if class is not duplicated... // and annotations classes.push(&name[..]); for base in bases.iter() { let name = name_from_expr(&base.node); typenames.push(name); } // may check if fields/functions are not duplicated for stmt in body.iter() { match &stmt.node { StatementType::AnnAssign { annotation, .. } => { typename_from_expr(&mut typenames, &annotation.node) } StatementType::FunctionDef { .. } => { typename_from_fn(&mut typenames, &stmt.node); } _ => unimplemented!(), } } } StatementType::FunctionDef { .. } => { // may check annotations typename_from_fn(&mut typenames, &stmt.node); } _ => (), } } let mut unknowns = Vec::new(); for n in typenames { if !PRIMITIVES.contains(&n) && !classes.contains(&n) && !unknowns.contains(&n) { unknowns.push(n); } } (classes, unknowns) } fn resolve_function<'b: 'a, 'a>( ctx: &'a TopLevelContext, fun: &'b StatementType, ) -> Result { if let StatementType::FunctionDef { args, returns, .. } = &fun { let args: Result, _> = args .args .iter() .map(|arg| type_from_expr(ctx, &arg.annotation.as_ref().unwrap().node)) .collect(); let args = args?; let result = match returns { Some(v) => Some(type_from_expr(ctx, &v.node)?), None => None, }; Ok(FnDef { args, result }) } else { unreachable!() } } pub fn resolve_signatures<'b: 'a, 'a>(ctx: &'a mut TopLevelContext<'a>, stmts: &'b [Statement]) { for stmt in stmts.iter() { match &stmt.node { StatementType::ClassDef { name, bases, body, .. } => { let mut parents = Vec::new(); for base in bases.iter() { let name = name_from_expr(&base.node); let c = ctx.get_type(name).unwrap(); let id = if let TypeEnum::ClassType(id) = c.as_ref() { *id } else { unreachable!() }; parents.push(id); } let mut fields = HashMap::new(); let mut functions = HashMap::new(); for stmt in body.iter() { match &stmt.node { StatementType::AnnAssign { target, annotation, .. } => { let name = name_from_expr(&target.node); let ty = type_from_expr(ctx, &annotation.node).unwrap(); fields.insert(name, ty); } StatementType::FunctionDef { name, .. } => { functions.insert(&name[..], resolve_function(ctx, &stmt.node).unwrap()); } _ => unimplemented!(), } } let class = ctx.get_type(name).unwrap(); let class = if let TypeEnum::ClassType(id) = class.as_ref() { ctx.get_class_def_mut(*id) } else { unreachable!() }; class.parents.extend_from_slice(&parents); class.base.fields.clone_from(&fields); class.base.methods.clone_from(&functions); } StatementType::FunctionDef { name, .. } => { ctx.add_fn(&name[..], resolve_function(ctx, &stmt.node).unwrap()); } _ => unimplemented!(), } } } #[cfg(test)] mod tests { use super::*; use indoc::indoc; use rustpython_parser::parser::parse_program; #[test] fn test_get_classes() { let ast = parse_program(indoc! {" class Foo: a: int32 b: Test def test(self, a: int32) -> Test2: return b class Bar(Foo, 'FooBar'): def test2(self, a: list[Foo]) -> Test2: return b def test3(self, a: list[FooBar2]) -> Test2: return b " }) .unwrap(); let (mut classes, mut unknowns) = get_typenames(&ast.statements); let classes_count = classes.len(); let unknowns_count = unknowns.len(); classes.sort(); unknowns.sort(); assert_eq!(classes.len(), classes_count); assert_eq!(unknowns.len(), unknowns_count); assert_eq!(&classes, &["Bar", "Foo"]); assert_eq!(&unknowns, &["FooBar", "FooBar2", "Test", "Test2"]); } }