/// obtain class and function signature from AST use super::context::TopLevelContext; use super::primitives::*; use super::typedef::*; use rustpython_parser::ast::{ ComprehensionKind, ExpressionType, Statement, StatementType, StringGroup, }; use std::collections::HashMap; // TODO: fix condition checking, return error message instead of panic... fn typename_from_expr<'a>(typenames: &mut Vec<&'a str>, expr: &'a ExpressionType) { match expr { ExpressionType::Identifier { name } => typenames.push(&name), ExpressionType::String { value } => match value { StringGroup::Constant { value } => typenames.push(&value), _ => unimplemented!(), }, ExpressionType::Subscript { a, b } => { typename_from_expr(typenames, &b.node); typename_from_expr(typenames, &a.node) } _ => unimplemented!(), } } fn typename_from_fn<'a>(typenames: &mut Vec<&'a str>, fun: &'a StatementType) { match fun { StatementType::FunctionDef { args, returns, .. } => { for arg in args.args.iter() { if let Some(ann) = &arg.annotation { typename_from_expr(typenames, &ann.node); } } if let Some(returns) = &returns { typename_from_expr(typenames, &returns.node); } } _ => unreachable!(), } } fn name_from_expr<'a>(expr: &'a ExpressionType) -> &'a str { match &expr { ExpressionType::Identifier { name } => &name, ExpressionType::String { value } => match value { StringGroup::Constant { value } => &value, _ => unimplemented!(), }, _ => unimplemented!(), } } fn type_from_expr<'a>(ctx: &'a TopLevelContext, expr: &'a ExpressionType) -> Result { match expr { ExpressionType::Identifier { name } => { ctx.get_type(name).ok_or_else(|| "no such type".into()) } ExpressionType::String { value } => match value { StringGroup::Constant { value } => { ctx.get_type(&value).ok_or_else(|| "no such type".into()) } _ => unimplemented!(), }, ExpressionType::Subscript { a, b } => { if let ExpressionType::Identifier { name } = &a.node { match name.as_str() { "list" => { let ty = type_from_expr(ctx, &b.node)?; Ok(TypeEnum::ParametricType(LIST_TYPE, vec![ty]).into()) } "tuple" => { if let ExpressionType::Tuple { elements } = &b.node { let ty_list: Result, _> = elements .iter() .map(|v| type_from_expr(ctx, &v.node)) .collect(); Ok(TypeEnum::ParametricType(TUPLE_TYPE, ty_list?).into()) } else { Err("unsupported format".into()) } } _ => Err("no such parameterized type".into()), } } else { // we require a to be an identifier, for a[b] Err("unsupported format".into()) } } _ => Err("unsupported format".into()), } } pub fn get_typenames<'a>(stmts: &'a [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { let mut classes = Vec::new(); let mut typenames = Vec::new(); for stmt in stmts.iter() { match &stmt.node { StatementType::ClassDef { name, body, bases, .. } => { // check if class is not duplicated... // and annotations classes.push(&name[..]); for base in bases.iter() { let name = name_from_expr(&base.node); typenames.push(name); } // may check if fields/functions are not duplicated for stmt in body.iter() { match &stmt.node { StatementType::AnnAssign { annotation, .. } => { typename_from_expr(&mut typenames, &annotation.node) } StatementType::FunctionDef { .. } => { typename_from_fn(&mut typenames, &stmt.node); } _ => unimplemented!(), } } } StatementType::FunctionDef { .. } => { // may check annotations typename_from_fn(&mut typenames, &stmt.node); } _ => (), } } let mut unknowns = Vec::new(); for n in typenames { if !PRIMITIVES.contains(&n) && !classes.contains(&n) && !unknowns.contains(&n) { unknowns.push(n); } } (classes, unknowns) } fn resolve_function<'a>( ctx: &'a TopLevelContext, fun: &'a StatementType, method: bool, ) -> Result { if let StatementType::FunctionDef { args, returns, .. } = &fun { let args = if method { args.args[1..].iter() } else { args.args.iter() }; let args: Result, _> = args .map(|arg| type_from_expr(ctx, &arg.annotation.as_ref().unwrap().node)) .collect(); let args = args?; let result = match returns { Some(v) => Some(type_from_expr(ctx, &v.node)?), None => None, }; Ok(FnDef { args, result }) } else { unreachable!() } } fn get_expr_unknowns<'a>( defined: &mut Vec<&'a str>, unknowns: &mut Vec<&'a str>, expr: &'a ExpressionType, ) { match expr { ExpressionType::BoolOp { values, .. } => { for v in values.iter() { get_expr_unknowns(defined, unknowns, &v.node) } } ExpressionType::Binop { a, b, .. } => { get_expr_unknowns(defined, unknowns, &a.node); get_expr_unknowns(defined, unknowns, &b.node); } ExpressionType::Subscript { a, b } => { get_expr_unknowns(defined, unknowns, &a.node); get_expr_unknowns(defined, unknowns, &b.node); } ExpressionType::Unop { a, .. } => { get_expr_unknowns(defined, unknowns, &a.node); } ExpressionType::Compare { vals, .. } => { for v in vals.iter() { get_expr_unknowns(defined, unknowns, &v.node) } } ExpressionType::Attribute { value, .. } => { get_expr_unknowns(defined, unknowns, &value.node); } ExpressionType::Call { function, args, .. } => { get_expr_unknowns(defined, unknowns, &function.node); for v in args.iter() { get_expr_unknowns(defined, unknowns, &v.node) } } ExpressionType::List { elements } => { for v in elements.iter() { get_expr_unknowns(defined, unknowns, &v.node) } } ExpressionType::Tuple { elements } => { for v in elements.iter() { get_expr_unknowns(defined, unknowns, &v.node) } } ExpressionType::Comprehension { kind, generators } => { if generators.len() != 1 { unimplemented!() } let g = &generators[0]; get_expr_unknowns(defined, unknowns, &g.iter.node); let mut scoped = defined.clone(); get_expr_unknowns(defined, &mut scoped, &g.target.node); for if_expr in g.ifs.iter() { get_expr_unknowns(&mut scoped, unknowns, &if_expr.node); } match kind.as_ref() { ComprehensionKind::List { element } => { get_expr_unknowns(&mut scoped, unknowns, &element.node); } _ => unimplemented!(), } } ExpressionType::Slice { elements } => { for v in elements.iter() { get_expr_unknowns(defined, unknowns, &v.node); } } ExpressionType::Identifier { name } => { if !defined.contains(&name.as_str()) && !unknowns.contains(&name.as_str()) { unknowns.push(name); } } ExpressionType::IfExpression { test, body, orelse } => { get_expr_unknowns(defined, unknowns, &test.node); get_expr_unknowns(defined, unknowns, &body.node); get_expr_unknowns(defined, unknowns, &orelse.node); } _ => (), }; } struct ExprPattern<'a>(&'a ExpressionType, Vec, bool); impl<'a> ExprPattern<'a> { fn new(expr: &'a ExpressionType) -> ExprPattern { let mut pattern = ExprPattern(expr, Vec::new(), true); pattern.find_leaf(); pattern } fn pointed(&mut self) -> &'a ExpressionType { let mut current = self.0; for v in self.1.iter() { if let ExpressionType::Tuple { elements } = current { current = &elements[*v].node } else { unreachable!() } } current } fn find_leaf(&mut self) { let mut current = self.pointed(); while let ExpressionType::Tuple { elements } = current { if elements.is_empty() { break; } current = &elements[0].node; self.1.push(0); } } fn inc(&mut self) -> bool { loop { if self.1.is_empty() { return false; } let ind = self.1.pop().unwrap() + 1; let parent = self.pointed(); if let ExpressionType::Tuple { elements } = parent { if ind < elements.len() { self.1.push(ind); self.find_leaf(); return true; } } else { unreachable!() } } } } impl<'a> Iterator for ExprPattern<'a> { type Item = &'a ExpressionType; fn next(&mut self) -> Option { if self.2 { self.2 = false; Some(self.pointed()) } else if self.inc() { Some(self.pointed()) } else { None } } } fn get_stmt_unknowns<'a>( defined: &mut Vec<&'a str>, unknowns: &mut Vec<&'a str>, stmts: &'a [Statement], ) { for stmt in stmts.iter() { match &stmt.node { StatementType::Return { value } => { if let Some(value) = value { get_expr_unknowns(defined, unknowns, &value.node); } } StatementType::Assign { targets, value } => { get_expr_unknowns(defined, unknowns, &value.node); for target in targets.iter() { for node in ExprPattern::new(&target.node).into_iter() { if let ExpressionType::Identifier { name } = node { let name = name.as_str(); if !defined.contains(&name) { defined.push(name); } } else { get_expr_unknowns(defined, unknowns, node); } } } } StatementType::AugAssign { target, value, .. } => { get_expr_unknowns(defined, unknowns, &target.node); get_expr_unknowns(defined, unknowns, &value.node); } StatementType::AnnAssign { target, value, .. } => { get_expr_unknowns(defined, unknowns, &target.node); if let Some(value) = value { get_expr_unknowns(defined, unknowns, &value.node); } } StatementType::Expression { expression } => { get_expr_unknowns(defined, unknowns, &expression.node); } StatementType::Global { names } => { for name in names.iter() { let name = name.as_str(); if !unknowns.contains(&name) { unknowns.push(name); } } } StatementType::If { test, body, orelse } | StatementType::While { test, body, orelse } => { get_expr_unknowns(defined, unknowns, &test.node); // we are not very strict at this point... // some identifiers not treated as unknowns may not be resolved // but should be checked during type inference get_stmt_unknowns(defined, unknowns, body.as_slice()); if let Some(orelse) = orelse { get_stmt_unknowns(defined, unknowns, orelse.as_slice()); } } StatementType::For { is_async, target, iter, body, orelse } => { if *is_async { unimplemented!() } get_expr_unknowns(defined, unknowns, &iter.node); for node in ExprPattern::new(&target.node).into_iter() { if let ExpressionType::Identifier { name } = node { let name = name.as_str(); if !defined.contains(&name) { defined.push(name); } } else { get_expr_unknowns(defined, unknowns, node); } } get_stmt_unknowns(defined, unknowns, body.as_slice()); if let Some(orelse) = orelse { get_stmt_unknowns(defined, unknowns, orelse.as_slice()); } } _ => (), } } } pub fn resolve_signatures<'a>(ctx: &mut TopLevelContext<'a>, stmts: &'a [Statement]) { for stmt in stmts.iter() { match &stmt.node { StatementType::ClassDef { name, bases, body, .. } => { let mut parents = Vec::new(); for base in bases.iter() { let name = name_from_expr(&base.node); let c = ctx.get_type(name).unwrap(); let id = if let TypeEnum::ClassType(id) = c.as_ref() { *id } else { unreachable!() }; parents.push(id); } let mut fields = HashMap::new(); let mut functions = HashMap::new(); for stmt in body.iter() { match &stmt.node { StatementType::AnnAssign { target, annotation, .. } => { let name = name_from_expr(&target.node); let ty = type_from_expr(ctx, &annotation.node).unwrap(); fields.insert(name, ty); } StatementType::FunctionDef { name, .. } => { functions.insert( &name[..], resolve_function(ctx, &stmt.node, true).unwrap(), ); } _ => unimplemented!(), } } let class = ctx.get_type(name).unwrap(); let class = if let TypeEnum::ClassType(id) = class.as_ref() { ctx.get_class_def_mut(*id) } else { unreachable!() }; class.parents.extend_from_slice(&parents); class.base.fields.clone_from(&fields); class.base.methods.clone_from(&functions); } StatementType::FunctionDef { name, .. } => { ctx.add_fn(&name[..], resolve_function(ctx, &stmt.node, false).unwrap()); } _ => unimplemented!(), } } } #[cfg(test)] mod tests { use super::*; use indoc::indoc; use rustpython_parser::parser::{parse_program, parse_statement}; #[test] fn test_get_classes() { let ast = parse_program(indoc! {" class Foo: a: int32 b: Test def test(self, a: int32) -> Test2: return self.b class Bar(Foo, 'FooBar'): def test2(self, a: list[Foo]) -> Test2: return self.b def test3(self, a: list[FooBar2]) -> Test2: return self.b " }) .unwrap(); let (mut classes, mut unknowns) = get_typenames(&ast.statements); let classes_count = classes.len(); let unknowns_count = unknowns.len(); classes.sort(); unknowns.sort(); assert_eq!(classes.len(), classes_count); assert_eq!(unknowns.len(), unknowns_count); assert_eq!(&classes, &["Bar", "Foo"]); assert_eq!(&unknowns, &["FooBar", "FooBar2", "Test", "Test2"]); } #[test] fn test_assignment() { let ast = parse_statement(indoc! {" ((a, b), c[i]) = core.foo(x, get_y()) " }) .unwrap(); let mut defined = Vec::new(); let mut unknowns = Vec::new(); get_stmt_unknowns(&mut defined, &mut unknowns, ast.as_slice()); defined.sort(); unknowns.sort(); assert_eq!(defined.as_slice(), &["a", "b"]); assert_eq!(unknowns.as_slice(), &["c", "core", "get_y", "i", "x"]); } }