diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index 3afdf63..93229fb 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -8,8 +8,7 @@ use crate::typecheck::primitives; use rustpython_parser::ast; use rustpython_parser::ast::fold::Fold; -struct Premapper; -impl ast::fold::Fold<()> for Premapper { +impl<'a> ast::fold::Fold<()> for InferenceContext<'a> { type TargetU = Option; type Error = String; @@ -18,26 +17,11 @@ impl ast::fold::Fold<()> for Premapper { } fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { - ast::fold::fold_expr(self, node) - } -} - -impl<'a> ast::fold::Fold> for InferenceContext<'a> { - type TargetU = Option; - type Error = String; - - fn map_user(&mut self, user: Option) -> Result { - Ok(user) - } - - fn fold_expr(&mut self, node: ast::Expr>) -> Result, Self::Error> { - assert_eq!(node.custom, None); + // assert_eq!(node.custom, None); - // pre-fold - let mut expr = node; - expr = match &expr.node { - ast::ExprKind::ListComp { .. } => self.fold_listcomp(expr)?, - _ => rustpython_parser::ast::fold::fold_expr(self, expr)? + let expr = match &node.node { + ast::ExprKind::ListComp { .. } => self.fold_listcomp(node)?, + _ => rustpython_parser::ast::fold::fold_expr(self, node)? }; Ok(ast::Expr { @@ -55,7 +39,7 @@ impl<'a> ast::fold::Fold> for InferenceContext<'a> { ast::ExprKind::Call {func, args, keywords} => self.infer_call(func, args, keywords), ast::ExprKind::Subscript {value, slice, ctx: _} => self.infer_subscript(value, slice), ast::ExprKind::IfExp {test, body, orelse} => self.infer_if_expr(test, body, orelse), - ast::ExprKind::ListComp {elt: _, generators: _} => Ok(expr.custom), // already folded + ast::ExprKind::ListComp {elt: _, generators: _} => panic!("should not earch here, the list comp should be folded before"), // already folded ast::ExprKind::Slice { .. } => Ok(None), // special handling for slice, which is supported _ => Err("not supported yet".into()) }?, @@ -332,7 +316,7 @@ impl<'a> InferenceContext<'a> { } // some pre-folds need special handling - fn fold_listcomp(&mut self, expr: ast::Expr>) -> Result>, String> { + fn fold_listcomp(&mut self, expr: ast::Expr<()>) -> Result>, String> { if let ast::Expr { location, custom: _, @@ -348,29 +332,25 @@ impl<'a> InferenceContext<'a> { return Err("async is not supported".into()); } - // fold iter first since it does not contain new identifiers - let gen_first_folded = ast::Comprehension { - target: gen.target, - iter: Box::new(self.fold_expr(*gen.iter)?), // fold here - ifs: gen.ifs, - is_async: gen.is_async - }; + let ast::Comprehension {iter, + target, + ifs, + is_async} = gen; + let iter_folded = Box::new(self.fold_expr(*iter)?); if let TypeEnum::ParametricType( primitives::LIST_TYPE, - ls) = gen_first_folded - .iter + ls) = iter_folded .custom .as_ref() .ok_or_else(|| "no value".to_string())? .as_ref() .clone() { self.with_scope(|ctx| -> Result>, String> { - ctx.infer_simple_binding(&gen_first_folded.target, ls[0].clone())?; + ctx.infer_simple_binding(&target, ls[0].clone())?; let elt_folded = Box::new(ctx.fold_expr(*elt)?); - let target_folded = Box::new(ctx.fold_expr(*gen_first_folded.target)?); - let ifs_folded = gen_first_folded - .ifs + let target_folded = Box::new(ctx.fold_expr(*target)?); + let ifs_folded = ifs .into_iter() .map(|x| ctx.fold_expr(x)) .collect::>>, _>>()?; @@ -391,8 +371,8 @@ impl<'a> InferenceContext<'a> { generators: vec![ast::Comprehension { target: target_folded, ifs: ifs_folded, - iter: gen_first_folded.iter, - is_async: gen_first_folded.is_async + iter: iter_folded, + is_async: is_async }] } }) @@ -408,7 +388,7 @@ impl<'a> InferenceContext<'a> { } } - fn infer_simple_binding(&mut self, name: &ast::Expr>, ty: Type) -> Result<(), String> { + fn infer_simple_binding(&mut self, name: &ast::Expr, ty: Type) -> Result<(), String> { match &name.node { ast::ExprKind::Name {id, ctx: _} => { if id == "_" { @@ -445,7 +425,6 @@ pub struct ExpressionInferencer<'a> { } impl<'a> ExpressionInferencer<'a> { pub fn fold_expr(&mut self, expr: ast::Expr) -> Result>, String> { - let expr = Premapper.fold_expr(expr)?; self.ctx.fold_expr(expr) } } @@ -653,7 +632,7 @@ pub mod test { let ast12 = rustpython_parser::parser::parse_expression("(1, True, 3, False)[1]").unwrap(); let folded = inf.fold_expr(ast1).unwrap(); - let folded_2 = Premapper.fold_expr(ast2).unwrap(); + let folded_2 = inf.fold_expr(ast2).unwrap(); let folded_3 = inf.fold_expr(ast3).unwrap(); let folded_4 = inf.fold_expr(ast4).unwrap(); let folded_5 = inf.fold_expr(ast5).unwrap(); @@ -677,6 +656,5 @@ pub mod test { println!("{:?}", folded_10.custom); println!("{:?}", folded_11.custom); println!("{:?}", folded_12.custom); - } } \ No newline at end of file