diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 1e8a2de..537656a 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -1,20 +1,41 @@ use super::type_inferencer::Inferencer; use super::typedef::Type; -use rustpython_parser::ast::{self, Expr, ExprKind, StmtKind}; +use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind}; use std::iter::once; impl<'a> Inferencer<'a> { + fn check_pattern( + &mut self, + pattern: &Expr>, + defined_identifiers: &mut Vec, + ) { + match &pattern.node { + ExprKind::Name { id, .. } => { + if !defined_identifiers.contains(id) { + defined_identifiers.push(id.clone()); + } + } + ExprKind::Tuple { elts, .. } => { + for elt in elts.iter() { + self.check_pattern(elt, defined_identifiers); + } + } + _ => unimplemented!(), + } + } + fn check_expr( &mut self, expr: &Expr>, defined_identifiers: &[String], ) -> Result<(), String> { + // there are some cases where the custom field is None if let Some(ty) = &expr.custom { let ty = self.unifier.get_ty(*ty); let ty = ty.as_ref().borrow(); - if ty.is_concrete() { + if !ty.is_concrete() { return Err(format!( - "expected concrete type at {:?} but got {}", + "expected concrete type at {} but got {}", expr.location, ty.get_type_name() )); @@ -23,7 +44,7 @@ impl<'a> Inferencer<'a> { match &expr.node { ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { - return Err(format!("unknown identifier {} (use before def?)", id)); + return Err(format!("unknown identifier {} (use before def?) at {}", id, expr.location)); } } ExprKind::List { elts, .. } @@ -34,14 +55,14 @@ impl<'a> Inferencer<'a> { } } ExprKind::Attribute { value, .. } => { - self.check_expr(value.as_ref(), defined_identifiers)?; + self.check_expr(value, defined_identifiers)?; } ExprKind::BinOp { left, right, .. } => { - self.check_expr(left.as_ref(), defined_identifiers)?; - self.check_expr(right.as_ref(), defined_identifiers)?; + self.check_expr(left, defined_identifiers)?; + self.check_expr(right, defined_identifiers)?; } ExprKind::UnaryOp { operand, .. } => { - self.check_expr(operand.as_ref(), defined_identifiers)?; + self.check_expr(operand, defined_identifiers)?; } ExprKind::Compare { left, comparators, .. @@ -51,13 +72,13 @@ impl<'a> Inferencer<'a> { } } ExprKind::Subscript { value, slice, .. } => { - self.check_expr(value.as_ref(), defined_identifiers)?; - self.check_expr(slice.as_ref(), defined_identifiers)?; + self.check_expr(value, defined_identifiers)?; + self.check_expr(slice, defined_identifiers)?; } ExprKind::IfExp { test, body, orelse } => { - self.check_expr(test.as_ref(), defined_identifiers)?; - self.check_expr(body.as_ref(), defined_identifiers)?; - self.check_expr(orelse.as_ref(), defined_identifiers)?; + self.check_expr(test, defined_identifiers)?; + self.check_expr(body, defined_identifiers)?; + self.check_expr(orelse, defined_identifiers)?; } ExprKind::Slice { lower, upper, step } => { for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()] @@ -67,10 +88,132 @@ impl<'a> Inferencer<'a> { self.check_expr(elt, defined_identifiers)?; } } - ExprKind::ListComp { .. } => unimplemented!(), - ExprKind::Lambda { .. } => unimplemented!(), - _ => {} + ExprKind::Lambda { args, body } => { + let mut defined_identifiers = defined_identifiers.to_vec(); + for arg in args.args.iter() { + if !defined_identifiers.contains(&arg.node.arg) { + defined_identifiers.push(arg.node.arg.clone()); + } + } + self.check_expr(body, &defined_identifiers)?; + } + ExprKind::ListComp { + elt, generators, .. + } => { + // in our type inference stage, we already make sure that there is only 1 generator + let ast::Comprehension { + target, iter, ifs, .. + } = &generators[0]; + self.check_expr(iter, defined_identifiers)?; + let mut defined_identifiers = defined_identifiers.to_vec(); + self.check_pattern(target, &mut defined_identifiers); + for term in once(elt.as_ref()).chain(ifs.iter()) { + self.check_expr(term, &defined_identifiers)?; + } + } + ExprKind::Call { + func, + args, + keywords, + } => { + for expr in once(func.as_ref()) + .chain(args.iter()) + .chain(keywords.iter().map(|v| v.node.value.as_ref())) + { + self.check_expr(expr, defined_identifiers)?; + } + } + ExprKind::Constant { .. } => {} + _ => { + println!("{:?}", expr.node); + unimplemented!() + } } Ok(()) } + + fn check_stmt( + &mut self, + stmt: &Stmt>, + defined_identifiers: &mut Vec, + ) -> Result { + match &stmt.node { + StmtKind::For { + target, + iter, + body, + orelse, + .. + } => { + self.check_expr(iter, defined_identifiers)?; + for stmt in orelse.iter() { + self.check_stmt(stmt, defined_identifiers)?; + } + let mut defined_identifiers = defined_identifiers.clone(); + self.check_pattern(target, &mut defined_identifiers); + for stmt in body.iter() { + self.check_stmt(stmt, &mut defined_identifiers)?; + } + Ok(false) + } + StmtKind::If { test, body, orelse } => { + self.check_expr(test, defined_identifiers)?; + let mut body_identifiers = defined_identifiers.clone(); + let mut orelse_identifiers = defined_identifiers.clone(); + let body_returned = self.check_block(body, &mut body_identifiers)?; + let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; + + for ident in body_identifiers.iter() { + if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { + defined_identifiers.push(ident.clone()) + } + } + Ok(body_returned && orelse_returned) + } + StmtKind::While { test, body, orelse } => { + self.check_expr(test, defined_identifiers)?; + let mut defined_identifiers = defined_identifiers.clone(); + self.check_block(body, &mut defined_identifiers)?; + self.check_block(orelse, &mut defined_identifiers)?; + Ok(false) + } + StmtKind::Expr { value } => { + self.check_expr(value, defined_identifiers)?; + Ok(false) + } + StmtKind::Assign { targets, value, .. } => { + self.check_expr(value, defined_identifiers)?; + for target in targets { + self.check_pattern(target, defined_identifiers); + } + Ok(false) + } + StmtKind::AnnAssign { target, value, .. } => { + if let Some(value) = value { + self.check_expr(value, defined_identifiers)?; + self.check_pattern(target, defined_identifiers); + } + Ok(false) + } + // break, return, raise, etc. + _ => Ok(false), + } + } + + pub fn check_block( + &mut self, + block: &[Stmt>], + defined_identifiers: &mut Vec, + ) -> Result { + let mut ret = false; + for stmt in block { + if ret { + return Err(format!("dead code at {:?}", stmt.location)); + } + if self.check_stmt(stmt, defined_identifiers)? { + ret = true; + } + } + Ok(ret) + } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index b7d296c..1f27dab 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -107,6 +107,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} + ast::StmtKind::Break | ast::StmtKind::Continue => {} ast::StmtKind::Return { value } => match (value, self.return_type) { (Some(v), Some(v1)) => { self.unifier.unify(v.custom.unwrap(), v1)?; @@ -130,12 +131,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { func, args, keywords, - } => self.fold_call(node.location, *func, args, keywords)?, + } => { + return self.fold_call(node.location, *func, args, keywords); + } ast::ExprKind::Lambda { args, body } => { - self.fold_lambda(node.location, *args, *body)? + return self.fold_lambda(node.location, *args, *body); } ast::ExprKind::ListComp { elt, generators } => { - self.fold_listcomp(node.location, *elt, generators)? + return self.fold_listcomp(node.location, *elt, generators); } _ => fold::fold_expr(self, node)?, }; diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index c85408f..be9b150 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -8,12 +8,12 @@ use rustpython_parser::parser::parse_program; use test_case::test_case; struct Resolver { - type_mapping: HashMap, + identifier_mapping: HashMap, } impl SymbolResolver for Resolver { fn get_symbol_type(&mut self, str: &str) -> Option { - self.type_mapping.get(str).cloned() + self.identifier_mapping.get(str).cloned() } fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option { @@ -35,12 +35,13 @@ struct TestEnvironment { pub calls: Vec>, pub primitives: PrimitiveStore, pub id_to_name: HashMap, + pub identifier_mapping: HashMap, } impl TestEnvironment { fn new() -> TestEnvironment { let mut unifier = Unifier::new(); - let mut type_mapping = HashMap::new(); + let mut identifier_mapping = HashMap::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: 0, fields: HashMap::new(), @@ -66,7 +67,7 @@ impl TestEnvironment { fields: HashMap::new(), params: HashMap::new(), }); - type_mapping.insert("None".into(), none); + identifier_mapping.insert("None".into(), none); let primitives = PrimitiveStore { int32, @@ -84,7 +85,7 @@ impl TestEnvironment { params: [(id, v0)].iter().cloned().collect(), }); - type_mapping.insert( + identifier_mapping.insert( "Foo".into(), unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], @@ -105,13 +106,14 @@ impl TestEnvironment { .cloned() .collect(); - let resolver = Box::new(Resolver { type_mapping }) as Box; + let resolver = Box::new(Resolver { identifier_mapping: identifier_mapping.clone() }) as Box; TestEnvironment { unifier, resolver, primitives, id_to_name, + identifier_mapping, calls: Vec::new(), } } @@ -168,15 +170,20 @@ impl TestEnvironment { [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect() ; "listcomp test")] fn test_basic(source: &str, mapping: HashMap<&str, &str>) { + println!("source:\n{}", source); let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); + let mut defined_identifiers = env.identifier_mapping.keys().cloned().collect(); let mut inferencer = env.get_inferencer(); let statements = parse_program(source).unwrap(); - statements + let statements = statements .into_iter() .map(|v| inferencer.fold_stmt(v)) .collect::, _>>() .unwrap(); + + inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); + for (k, v) in inferencer.variable_mapping.iter() { let name = inferencer.unifier.stringify( *v,