For loop type inference fix #128

Merged
sb10q merged 2 commits from for_loop_type_fix into master 2021-12-19 18:01:49 +08:00
2 changed files with 33 additions and 11 deletions

View File

@ -24,11 +24,13 @@ impl<'a> Inferencer<'a> {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id); defined_identifiers.insert(*id);
} }
self.check_expr(pattern, defined_identifiers)?;
self.should_have_value(pattern)?; self.should_have_value(pattern)?;
Ok(()) Ok(())
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() { for elt in elts.iter() {
self.check_expr(pattern, defined_identifiers)?;
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }

View File

@ -122,9 +122,36 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}, },
} }
} }
ast::StmtKind::For { ref target, .. } => { ast::StmtKind::For { target, iter, body, orelse, config_comment, type_comment } => {
self.infer_pattern(target)?; self.infer_pattern(&target)?;
fold::fold_stmt(self, node)? let target = self.fold_expr(*target)?;
let iter = self.fold_expr(*iter)?;
if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) {
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
} else {
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
self.unify(list, iter.custom.unwrap(), &iter.location)?;
}
let body = body
.into_iter()
.map(|b| self.fold_stmt(b))
.collect::<Result<Vec<_>, _>>()?;
let orelse = orelse
.into_iter()
.map(|o| self.fold_stmt(o))
.collect::<Result<Vec<_>, _>>()?;
Located {
location: node.location,
node: ast::StmtKind::For {
target: Box::new(target),
iter: Box::new(iter),
body,
orelse,
config_comment,
type_comment,
},
custom: None
}
} }
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
for target in targets.iter_mut() { for target in targets.iter_mut() {
@ -201,14 +228,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
_ => fold::fold_stmt(self, node)?, _ => fold::fold_stmt(self, node)?,
}; };
match &stmt.node { match &stmt.node {
ast::StmtKind::For { target, iter, .. } => { ast::StmtKind::For { .. } => {}
if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) {
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
} else {
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
self.unify(list, iter.custom.unwrap(), &iter.location)?;
}
}
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
} }