diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index 58a17f347..3afdf63d2 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -36,7 +36,7 @@ impl<'a> ast::fold::Fold> for InferenceContext<'a> { // pre-fold let mut expr = node; expr = match &expr.node { - ast::ExprKind::ListComp { .. } => self.prefold_list_comprehension(expr)?, + ast::ExprKind::ListComp { .. } => self.fold_listcomp(expr)?, _ => rustpython_parser::ast::fold::fold_expr(self, expr)? }; @@ -55,7 +55,7 @@ impl<'a> ast::fold::Fold> for InferenceContext<'a> { ast::ExprKind::Call {func, args, keywords} => self.infer_call(func, args, keywords), ast::ExprKind::Subscript {value, slice, ctx: _} => self.infer_subscript(value, slice), ast::ExprKind::IfExp {test, body, orelse} => self.infer_if_expr(test, body, orelse), - ast::ExprKind::ListComp {elt, generators} => self.infer_list_comprehesion(elt, generators), + ast::ExprKind::ListComp {elt: _, generators: _} => Ok(expr.custom), // already folded ast::ExprKind::Slice { .. } => Ok(None), // special handling for slice, which is supported _ => Err("not supported yet".into()) }?, @@ -318,7 +318,7 @@ impl<'a> InferenceContext<'a> { } } - fn infer_list_comprehesion(&self, elt: &Box>>, generators: &Vec>>) -> Result, String> { + fn _infer_list_comprehesion(&self, elt: &Box>>, generators: &Vec>>) -> Result, String> { if generators[0] .ifs .iter() @@ -328,39 +328,37 @@ impl<'a> InferenceContext<'a> { vec![elt.custom.clone().ok_or_else(|| "elements should have value".to_string())?]).into())) } else { Err("test must be bool".into()) - } + } } // some pre-folds need special handling - fn prefold_list_comprehension(&mut self, expr: ast::Expr>) -> Result>, String> { + fn fold_listcomp(&mut self, expr: ast::Expr>) -> Result>, String> { if let ast::Expr { location, - custom, + custom: _, node: ast::ExprKind::ListComp { elt, - generators}} = expr { + mut generators}} = expr { // if is list comprehension, need special pre-fold if generators.len() != 1 { return Err("only 1 generator statement is supported".into()); } - if generators[0].is_async { + let gen = generators.remove(0); + if gen.is_async { return Err("async is not supported".into()); } // fold iter first since it does not contain new identifiers - let generators_first_folded = generators - .into_iter() - .map(|x| -> Result>, String> {Ok(ast::Comprehension { - target: x.target, - iter: Box::new(self.fold_expr(*x.iter)?), // fold here - ifs: x.ifs, - is_async: x.is_async - })}) - .collect::, _>>()?; + let gen_first_folded = ast::Comprehension { + target: gen.target, + iter: Box::new(self.fold_expr(*gen.iter)?), // fold here + ifs: gen.ifs, + is_async: gen.is_async + }; if let TypeEnum::ParametricType( primitives::LIST_TYPE, - ls) = generators_first_folded[0] + ls) = gen_first_folded .iter .custom .as_ref() @@ -368,29 +366,39 @@ impl<'a> InferenceContext<'a> { .as_ref() .clone() { self.with_scope(|ctx| -> Result>, String> { - ctx.infer_simple_binding(&generators_first_folded[0].target, ls[0].clone())?; - Ok(ast::Expr { - location, - custom, - node: ast::ExprKind::ListComp { // now fold things with new name - elt: - Box::new(ctx.fold_expr(*elt)?), - generators: - generators_first_folded - .into_iter() - .map(|x| -> Result>, String> {Ok(ast::Comprehension { - target: Box::new(ctx.fold_expr(*x.target)?), - iter: x.iter, - ifs: x - .ifs - .into_iter() - .map(|x| ctx.fold_expr(x)) - .collect::, _>>()?, - is_async: x.is_async - })}) - .collect::, _>>()? - } - }) + ctx.infer_simple_binding(&gen_first_folded.target, ls[0].clone())?; + let elt_folded = Box::new(ctx.fold_expr(*elt)?); + let target_folded = Box::new(ctx.fold_expr(*gen_first_folded.target)?); + let ifs_folded = gen_first_folded + .ifs + .into_iter() + .map(|x| ctx.fold_expr(x)) + .collect::>>, _>>()?; + + if ifs_folded + .iter() + .all(|x| x.custom == Some(ctx.get_primitive(primitives::BOOL_TYPE))) { + Ok(ast::Expr { + location, + custom: Some(TypeEnum::ParametricType( + primitives::LIST_TYPE, + vec![elt_folded + .custom + .clone() + .ok_or_else(|| "elements cannot be typped".to_string())?]).into()), + node: ast::ExprKind::ListComp { + elt: elt_folded, + generators: vec![ast::Comprehension { + target: target_folded, + ifs: ifs_folded, + iter: gen_first_folded.iter, + is_async: gen_first_folded.is_async + }] + } + }) + } else { + Err("test must be bool".into()) + } }).1 } else { Err("iteration is supported for list only".into())