From ff27a1697e26ec84f0d657b81ca09c3e081a8090 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sun, 12 Dec 2021 05:39:48 +0800 Subject: [PATCH] nac3core: fix for loop type inference --- nac3core/src/typecheck/type_inferencer/mod.rs | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 4ab40e6..c18bed8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -122,9 +122,36 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { }, } } - ast::StmtKind::For { ref target, .. } => { - self.infer_pattern(target)?; - fold::fold_stmt(self, node)? + ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => { + self.infer_pattern(&target)?; + let target = self.fold_expr(*target)?; + let iter = self.fold_expr(*iter)?; + if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) { + self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; + } else { + let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + self.unify(list, iter.custom.unwrap(), &iter.location)?; + } + let body = body + .into_iter() + .map(|b| self.fold_stmt(b)) + .collect::, _>>()?; + let orelse = orelse + .into_iter() + .map(|o| self.fold_stmt(o)) + .collect::, _>>()?; + Located { + location: node.location, + node: ast::StmtKind::For { + target: Box::new(target), + iter: Box::new(iter), + body, + orelse, + config_comment, + type_comment, + }, + custom: None + } } ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { for target in targets.iter_mut() { @@ -201,14 +228,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { _ => fold::fold_stmt(self, node)?, }; match &stmt.node { - ast::StmtKind::For { target, iter, .. } => { - if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) { - self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; - } else { - let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - self.unify(list, iter.custom.unwrap(), &iter.location)?; - } - } + ast::StmtKind::For { .. } => {} ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; }