diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 624182b6..45ef892d 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -31,6 +31,7 @@ pub struct Inferencer<'a> { pub variable_mapping: HashMap, pub calls: &'a mut Vec>, pub primitives: &'a PrimitiveStore, + pub return_type: Option } struct NaiveFolder(); @@ -106,6 +107,20 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } } ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} + ast::StmtKind::Return { value } => { + match (value, self.return_type) { + (Some(v), Some(v1)) => { + self.unifier.unify(v.custom.unwrap(), v1)?; + } + (Some(_), None) => { + return Err("Unexpected return value".to_string()); + } + (None, Some(_)) => { + return Err("Expected return value".to_string()); + } + (None, None) => {} + } + } _ => return Err("Unsupported statement type".to_string()), }; Ok(stmt) @@ -227,6 +242,7 @@ impl<'a> Inferencer<'a> { variable_mapping, calls: self.calls, primitives: self.primitives, + return_type: self.return_type }; let fun = FunSignature { args: fn_args @@ -275,6 +291,7 @@ impl<'a> Inferencer<'a> { variable_mapping, calls: self.calls, primitives: self.primitives, + return_type: self.return_type }; let elt = new_context.fold_expr(elt)?; let generator = generators.pop().unwrap(); diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 3e85c245..c85408f7 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -123,6 +123,7 @@ impl TestEnvironment { variable_mapping: Default::default(), calls: &mut self.calls, primitives: &mut self.primitives, + return_type: None } } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 9a355c33..421f58ce 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -177,7 +177,7 @@ impl Unifier { } TypeEnum::TObj { obj_id, params, .. } => { let name = obj_to_name(*obj_id); - if params.len() > 0 { + if !params.is_empty() { let mut params = params .values() .map(|v| self.stringify(*v, obj_to_name, var_to_name));