diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs index 18cb87e7..728bb7cc 100644 --- a/nac3core/src/typecheck/type_inferencer.rs +++ b/nac3core/src/typecheck/type_inferencer.rs @@ -30,6 +30,15 @@ pub struct Inferencer<'a> { primitives: &'a PrimitiveStore, } +struct NaiveFolder(); +impl fold::Fold<()> for NaiveFolder { + type TargetU = Option; + type Error = String; + fn map_user(&mut self, _: ()) -> Result { + Ok(None) + } +} + impl<'a> fold::Fold<()> for Inferencer<'a> { type TargetU = Option; type Error = String; @@ -38,6 +47,66 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { Ok(None) } + fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result, Self::Error> { + let stmt = match node.node { + // we don't want fold over type annotation + ast::StmtKind::AnnAssign { + target, + annotation, + value, + simple, + } => { + let target = Box::new(fold::fold_expr(self, *target)?); + let value = if let Some(v) = value { + let ty = Box::new(fold::fold_expr(self, *v)?); + self.unifier + .unify(target.custom.unwrap(), ty.custom.unwrap())?; + Some(ty) + } else { + None + }; + let annotation_type = self + .resolver + .parse_type_name(annotation.as_ref()) + .ok_or_else(|| "cannot parse type name".to_string())?; + self.unifier.unify(annotation_type, target.custom.unwrap())?; + let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); + Located { + location: node.location, + custom: None, + node: ast::StmtKind::AnnAssign { + target, + annotation, + value, + simple, + }, + } + } + _ => fold::fold_stmt(self, node)?, + }; + match &stmt.node { + ast::StmtKind::For { target, iter, .. } => { + let list = self.unifier.add_ty(TypeEnum::TList { + ty: target.custom.unwrap(), + }); + self.unifier.unify(list, iter.custom.unwrap())?; + } + ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { + self.unifier + .unify(test.custom.unwrap(), self.primitives.bool)?; + } + ast::StmtKind::Assign { targets, value, .. } => { + for target in targets.iter() { + self.unifier + .unify(target.custom.unwrap(), value.custom.unwrap())?; + } + } + ast::StmtKind::AnnAssign { .. } => {} + _ => return Err("Unsupported statement type".to_string()) + }; + Ok(stmt) + } + fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { let expr = match node.node { ast::ExprKind::Call {