From 799ed58d2179324e13f609850c8ebcf906f8f1f8 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 22 Sep 2021 19:25:47 +0800 Subject: [PATCH] nac3core/type_inferencer: avoid type var for assign --- nac3core/src/typecheck/type_inferencer/mod.rs | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 110acf42..af7008ad 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -107,6 +107,50 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { fold::fold_stmt(self, node)? } ast::StmtKind::Assign { ref targets, .. } => { + if targets.iter().all(|t| matches!(t.node, ast::ExprKind::Name { .. })) { + if let ast::StmtKind::Assign { targets, value, .. } = node.node { + let value = self.fold_expr(*value)?; + let value_ty = value.custom.unwrap(); + let targets: Result, _> = targets.into_iter().map(|target| { + if let ast::ExprKind::Name { id, ctx } = target.node { + self.defined_identifiers.insert(id); + let target_ty = if let Some(ty) = self.variable_mapping.get(&id) { + *ty + } else { + let unifier = &mut self.unifier; + self + .function_data + .resolver + .get_symbol_type(unifier, self.primitives, id) + .unwrap_or_else(|| { + self.variable_mapping.insert(id, value_ty); + value_ty + }) + }; + let location = target.location; + self.unifier.unify(value_ty, target_ty).map(|_| Located { + location, + node: ast::ExprKind::Name { id, ctx }, + custom: Some(target_ty) + }) + } else { + unreachable!() + } + }).collect(); + let targets = targets?; + return Ok(Located { + location: node.location, + node: ast::StmtKind::Assign { + targets, + value: Box::new(value), + type_comment: None, + }, + custom: None + }); + } else { + unreachable!() + } + } for target in targets { self.infer_pattern(target)?; }