diff --git a/nac3core/src/type_check/mod.rs b/nac3core/src/type_check/mod.rs index dfad981a2b..c8e4943fa3 100644 --- a/nac3core/src/type_check/mod.rs +++ b/nac3core/src/type_check/mod.rs @@ -5,3 +5,5 @@ mod magic_methods; pub mod primitives; pub mod statement_check; pub mod typedef; +pub mod signature; + diff --git a/nac3core/src/type_check/primitives.rs b/nac3core/src/type_check/primitives.rs index 0aa076ddf6..cf2900063d 100644 --- a/nac3core/src/type_check/primitives.rs +++ b/nac3core/src/type_check/primitives.rs @@ -2,6 +2,8 @@ use super::context::*; use super::typedef::{TypeEnum::*, *}; use std::collections::HashMap; +pub const PRIMITIVES: [&str; 6] = ["int32", "int64", "float", "bool", "list", "tuple"]; + pub const TUPLE_TYPE: ParamId = ParamId(0); pub const LIST_TYPE: ParamId = ParamId(1); diff --git a/nac3core/src/type_check/signature.rs b/nac3core/src/type_check/signature.rs new file mode 100644 index 0000000000..04d3fc4ba1 --- /dev/null +++ b/nac3core/src/type_check/signature.rs @@ -0,0 +1,121 @@ +/// obtain class and function signature from AST +use super::primitives::PRIMITIVES; +use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup}; + +fn typename_from_expr<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, expr: &'b ExpressionType) { + match expr { + ExpressionType::Identifier { name } => typenames.push(&name), + ExpressionType::String { value } => match value { + StringGroup::Constant { value } => typenames.push(&value), + _ => unimplemented!(), + }, + ExpressionType::Subscript { a, b } => { + typename_from_expr(typenames, &b.node); + typename_from_expr(typenames, &a.node) + } + _ => unimplemented!(), + } +} + +fn typename_from_fn<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, fun: &'b StatementType) { + match fun { + StatementType::FunctionDef { args, returns, .. } => { + for arg in args.args.iter() { + if let Some(ann) = &arg.annotation { + typename_from_expr(typenames, &ann.node); + } + } + if let Some(returns) = &returns { + typename_from_expr(typenames, &returns.node); + } + } + _ => unreachable!(), + } +} + +pub fn resolve_classes<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { + let mut classes = Vec::new(); + let mut typenames = Vec::new(); + for stmt in stmts.iter() { + match &stmt.node { + StatementType::ClassDef { + name, body, bases, .. + } => { + // check if class is not duplicated... + // and annotations + classes.push(&name[..]); + for base in bases.iter() { + let name = match &base.node { + ExpressionType::Identifier { name } => name, + ExpressionType::String { value } => match value { + StringGroup::Constant { value } => value, + _ => unimplemented!(), + }, + _ => unimplemented!(), + }; + typenames.push(&name[..]); + } + // may check if fields/functions are not duplicated + for stmt in body.iter() { + match &stmt.node { + StatementType::AnnAssign { annotation, .. } => { + typename_from_expr(&mut typenames, &annotation.node) + } + StatementType::FunctionDef { .. } => { + typename_from_fn(&mut typenames, &stmt.node); + } + _ => unimplemented!(), + } + } + } + StatementType::FunctionDef { .. } => { + // may check annotations + typename_from_fn(&mut typenames, &stmt.node); + } + _ => (), + } + } + let mut unknowns = Vec::new(); + for n in typenames { + if !PRIMITIVES.contains(&n) && !classes.contains(&n) && !unknowns.contains(&n) { + unknowns.push(n); + } + } + (classes, unknowns) +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + use rustpython_parser::parser::parse_program; + + #[test] + fn test_get_classes() { + let ast = parse_program(indoc! {" + class Foo: + a: int32 + b: Test + + def test(self, a: int32) -> Test2: + return b + + class Bar(Foo, 'FooBar'): + def test2(self, a: list[Foo]) -> Test2: + return b + + def test3(self, a: list[FooBar2]) -> Test2: + return b + " }) + .unwrap(); + let (mut classes, mut unknowns) = resolve_classes(&ast.statements); + let classes_count = classes.len(); + let unknowns_count = unknowns.len(); + classes.sort(); + unknowns.sort(); + assert_eq!(classes.len(), classes_count); + assert_eq!(unknowns.len(), unknowns_count); + assert_eq!(&classes, &["Bar", "Foo"]); + assert_eq!(&unknowns, &["FooBar", "FooBar2", "Test", "Test2"]); + } +}