diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index fbed5d3..86f9f42 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -678,45 +678,47 @@ impl TopLevelComposer { .collect::, _>>()? }; - let return_ty_annotation = { - let return_annotation = returns - .as_ref() - .ok_or_else(|| "function return type needed".to_string())? - .as_ref(); - parse_ast_to_type_annotation_kinds( - resolver, - &temp_def_list, - unifier, - primitives_store, - return_annotation, - )? - }; + let return_ty = { + if let Some(returns) = returns { + let return_ty_annotation = { + let return_annotation = returns.as_ref(); + parse_ast_to_type_annotation_kinds( + resolver, + &temp_def_list, + unifier, + primitives_store, + return_annotation, + )? + }; - let type_vars_within = - get_type_var_contained_in_type_annotation(&return_ty_annotation) - .into_iter() - .map(|x| -> Result<(u32, Type), String> { - if let TypeAnnotation::TypeVarKind(ty) = x { - Ok((Self::get_var_id(ty, unifier)?, ty)) - } else { - unreachable!("must be type var here") + let type_vars_within = + get_type_var_contained_in_type_annotation(&return_ty_annotation) + .into_iter() + .map(|x| -> Result<(u32, Type), String> { + if let TypeAnnotation::TypeVarKind(ty) = x { + Ok((Self::get_var_id(ty, unifier)?, ty)) + } else { + unreachable!("must be type var here") + } + }) + .collect::, _>>()?; + for (id, ty) in type_vars_within { + if let Some(prev_ty) = function_var_map.insert(id, ty) { + // if already have the type inserted, make sure they are the same thing + assert_eq!(prev_ty, ty); } - }) - .collect::, _>>()?; - for (id, ty) in type_vars_within { - if let Some(prev_ty) = function_var_map.insert(id, ty) { - // if already have the type inserted, make sure they are the same thing - assert_eq!(prev_ty, ty); + } + + get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + primitives_store, + &return_ty_annotation, + )? + } else { + primitives_store.none } - } - - let return_ty = get_type_from_type_annotation_kinds( - &temp_def_list, - unifier, - primitives_store, - &return_ty_annotation, - )?; - + }; let function_ty = unifier.add_ty(TypeEnum::TFunc( FunSignature { args: arg_types, ret: return_ty, vars: function_var_map } .into(), @@ -883,37 +885,45 @@ impl TopLevelComposer { let ret_type = { if name != "__init__" { - let result = returns - .as_ref() - .ok_or_else(|| "method return type annotation needed".to_string())? - .as_ref(); - let annotation = parse_ast_to_type_annotation_kinds( - class_resolver, - temp_def_list, - unifier, - primitives, - result, - )?; + if let Some(result) = returns { + let result = result.as_ref(); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + result, + )?; - // find type vars within this return type annotation - let type_vars_within = - get_type_var_contained_in_type_annotation(&annotation); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVarKind(ty) = type_var_within { - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); + // find type vars within this return type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&annotation); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + if let TypeAnnotation::TypeVarKind(ty) = type_var_within { + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); + } + } else { + unreachable!("must be type var annotation"); } - } else { - unreachable!("must be type var annotation"); } - } - let dummy_return_type = unifier.get_fresh_var().0; - type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); - dummy_return_type + let dummy_return_type = unifier.get_fresh_var().0; + type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); + dummy_return_type + } else { + // if do not have return annotation, return none + // for uniform handling, still use type annoatation + let dummy_return_type = unifier.get_fresh_var().0; + type_var_to_concrete_def.insert( + dummy_return_type, + TypeAnnotation::PrimitiveKind(primitives.none), + ); + dummy_return_type + } } else { // if is the "__init__" function, the return type is self let dummy_return_type = unifier.get_fresh_var().0; diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index df36ca9..a270eb0 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -56,6 +56,10 @@ impl SymbolResolver for Resolver { def fun(self): self.b = self.b + 3.0 + "}, + indoc! {" + def foo(a: float): + a + 1.0 "} ] )]