diff --git a/nac3core/src/type_check/signature.rs b/nac3core/src/type_check/signature.rs index 7b26e827f3..b1e35db0d6 100644 --- a/nac3core/src/type_check/signature.rs +++ b/nac3core/src/type_check/signature.rs @@ -2,7 +2,9 @@ use super::context::TopLevelContext; use super::primitives::*; use super::typedef::*; -use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup, ComprehensionKind}; +use rustpython_parser::ast::{ + ComprehensionKind, ExpressionType, Statement, StatementType, StringGroup, +}; use std::collections::HashMap; // TODO: fix condition checking, return error message instead of panic... @@ -49,10 +51,7 @@ fn name_from_expr<'a>(expr: &'a ExpressionType) -> &'a str { } } -fn type_from_expr<'a>( - ctx: &'a TopLevelContext, - expr: &'a ExpressionType, -) -> Result { +fn type_from_expr<'a>(ctx: &'a TopLevelContext, expr: &'a ExpressionType) -> Result { match expr { ExpressionType::Identifier { name } => { ctx.get_type(name).ok_or_else(|| "no such type".into()) @@ -161,7 +160,11 @@ fn resolve_function<'a>( } } -fn get_expression_unknowns<'a>(defined: &mut Vec<&'a str>, unknowns: &mut Vec<&'a str>, expr: &'a ExpressionType) { +fn get_expression_unknowns<'a>( + defined: &[&'a str], + unknowns: &mut Vec<&'a str>, + expr: &'a ExpressionType, +) { match expr { ExpressionType::BoolOp { values, .. } => { for v in values.iter() { @@ -209,16 +212,16 @@ fn get_expression_unknowns<'a>(defined: &mut Vec<&'a str>, unknowns: &mut Vec<&' } let g = &generators[0]; get_expression_unknowns(defined, unknowns, &g.iter.node); - let mut scoped = defined.clone(); + let mut scoped = defined.to_owned(); get_expression_unknowns(defined, &mut scoped, &g.target.node); for if_expr in g.ifs.iter() { - get_expression_unknowns(&mut scoped, unknowns, &if_expr.node); + get_expression_unknowns(&scoped, unknowns, &if_expr.node); } match kind.as_ref() { ComprehensionKind::List { element } => { - get_expression_unknowns(&mut scoped, unknowns, &element.node); + get_expression_unknowns(&scoped, unknowns, &element.node); } - _ => unimplemented!() + _ => unimplemented!(), } } ExpressionType::Slice { elements } => { @@ -236,10 +239,95 @@ fn get_expression_unknowns<'a>(defined: &mut Vec<&'a str>, unknowns: &mut Vec<&' get_expression_unknowns(defined, unknowns, &body.node); get_expression_unknowns(defined, unknowns, &orelse.node); } - _ => () + _ => (), }; } +pub fn get_pattern_match_unknowns<'a>( + defined: &mut Vec<&'a str>, + unknowns: &mut Vec<&'a str>, + expr: &'a ExpressionType, +) { + match expr { + ExpressionType::Identifier { name } => { + defined.push(&name.as_str()); + } + ExpressionType::Tuple { elements } => { + for v in elements.iter() { + get_pattern_match_unknowns(defined, unknowns, &v.node); + } + } + _ => { + get_expression_unknowns(defined, unknowns, expr); + } + } +} + +pub fn get_statement_unknowns<'a>( + defined: &mut Vec<&'a str>, + unknowns: &mut Vec<&'a str>, + stmts: &'a [Statement], +) { + for stmt in stmts.iter() { + match &stmt.node { + StatementType::Return { value } => { + if let Some(v) = value { + get_expression_unknowns(defined, unknowns, &v.node); + } + } + StatementType::Assign { targets, value } => { + get_expression_unknowns(defined, unknowns, &value.node); + for t in targets.iter() { + get_pattern_match_unknowns(defined, unknowns, &t.node); + } + } + StatementType::AugAssign { target, value, .. } => { + get_expression_unknowns(defined, unknowns, &target.node); + get_expression_unknowns(defined, unknowns, &value.node); + } + StatementType::AnnAssign { target, value, .. } => { + if let Some(v) = value { + get_expression_unknowns(defined, unknowns, &v.node); + } + get_pattern_match_unknowns(defined, unknowns, &target.node); + } + StatementType::Expression { expression } => { + get_expression_unknowns(defined, unknowns, &expression.node); + } + StatementType::If { test, body, orelse } + | StatementType::While { test, body, orelse } => { + get_expression_unknowns(defined, unknowns, &test.node); + get_statement_unknowns(defined, unknowns, body.as_slice()); + if let Some(orelse) = orelse { + get_statement_unknowns(defined, unknowns, orelse.as_slice()); + } + } + StatementType::For { is_async, target, body, orelse, iter } => { + if *is_async { + unimplemented!(); + } + let mut scoped = defined.to_owned(); + get_expression_unknowns(defined, unknowns, &iter.node); + get_expression_unknowns(defined, &mut scoped, &target.node); + get_statement_unknowns(&mut scoped, unknowns, body.as_slice()); + if let Some(orelse) = orelse { + get_statement_unknowns(&mut scoped, unknowns, orelse.as_slice()); + } + } + StatementType::With { is_async, items, body } => { + if *is_async { + unimplemented!(); + } + let mut scoped = defined.to_owned(); + for item in items.iter() { + + } + } + _ => {} + } + } +} + pub fn resolve_signatures<'a>(ctx: &mut TopLevelContext<'a>, stmts: &'a [Statement]) { for stmt in stmts.iter() { match &stmt.node {