diff --git a/nac3core/src/type_check/signature.rs b/nac3core/src/type_check/signature.rs index 81895252..7b26e827 100644 --- a/nac3core/src/type_check/signature.rs +++ b/nac3core/src/type_check/signature.rs @@ -2,7 +2,7 @@ use super::context::TopLevelContext; use super::primitives::*; use super::typedef::*; -use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup}; +use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup, ComprehensionKind}; use std::collections::HashMap; // TODO: fix condition checking, return error message instead of panic... @@ -161,6 +161,85 @@ fn resolve_function<'a>( } } +fn get_expression_unknowns<'a>(defined: &mut Vec<&'a str>, unknowns: &mut Vec<&'a str>, expr: &'a ExpressionType) { + match expr { + ExpressionType::BoolOp { values, .. } => { + for v in values.iter() { + get_expression_unknowns(defined, unknowns, &v.node) + } + } + ExpressionType::Binop { a, b, .. } => { + get_expression_unknowns(defined, unknowns, &a.node); + get_expression_unknowns(defined, unknowns, &b.node); + } + ExpressionType::Subscript { a, b } => { + get_expression_unknowns(defined, unknowns, &a.node); + get_expression_unknowns(defined, unknowns, &b.node); + } + ExpressionType::Unop { a, .. } => { + get_expression_unknowns(defined, unknowns, &a.node); + } + ExpressionType::Compare { vals, .. } => { + for v in vals.iter() { + get_expression_unknowns(defined, unknowns, &v.node) + } + } + ExpressionType::Attribute { value, .. } => { + get_expression_unknowns(defined, unknowns, &value.node); + } + ExpressionType::Call { function, args, .. } => { + get_expression_unknowns(defined, unknowns, &function.node); + for v in args.iter() { + get_expression_unknowns(defined, unknowns, &v.node) + } + } + ExpressionType::List { elements } => { + for v in elements.iter() { + get_expression_unknowns(defined, unknowns, &v.node) + } + } + ExpressionType::Tuple { elements } => { + for v in elements.iter() { + get_expression_unknowns(defined, unknowns, &v.node) + } + } + ExpressionType::Comprehension { kind, generators } => { + if generators.len() != 1 { + unimplemented!() + } + let g = &generators[0]; + get_expression_unknowns(defined, unknowns, &g.iter.node); + let mut scoped = defined.clone(); + get_expression_unknowns(defined, &mut scoped, &g.target.node); + for if_expr in g.ifs.iter() { + get_expression_unknowns(&mut scoped, unknowns, &if_expr.node); + } + match kind.as_ref() { + ComprehensionKind::List { element } => { + get_expression_unknowns(&mut scoped, unknowns, &element.node); + } + _ => unimplemented!() + } + } + ExpressionType::Slice { elements } => { + for v in elements.iter() { + get_expression_unknowns(defined, unknowns, &v.node); + } + } + ExpressionType::Identifier { name } => { + if !defined.contains(&name.as_str()) && !unknowns.contains(&name.as_str()) { + unknowns.push(name); + } + } + ExpressionType::IfExpression { test, body, orelse } => { + get_expression_unknowns(defined, unknowns, &test.node); + get_expression_unknowns(defined, unknowns, &body.node); + get_expression_unknowns(defined, unknowns, &orelse.node); + } + _ => () + }; +} + pub fn resolve_signatures<'a>(ctx: &mut TopLevelContext<'a>, stmts: &'a [Statement]) { for stmt in stmts.iter() { match &stmt.node {