diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 13f9b5e0..37316f76 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -642,7 +642,6 @@ impl TopLevelComposer { let have_unique_fuction_parameter_name = args.args.iter().all(|x| { defined_paramter_name.insert(x.node.arg.clone()) && !keyword_list.contains(&x.node.arg) - && "self" != x.node.arg }); if !have_unique_fuction_parameter_name { return Err("top level function must have unique parameter names \ @@ -837,57 +836,63 @@ impl TopLevelComposer { let mut defined_paramter_name: HashSet = HashSet::new(); let have_unique_fuction_parameter_name = args.args.iter().all(|x| { defined_paramter_name.insert(x.node.arg.clone()) - && !keyword_list.contains(&x.node.arg) + && (!keyword_list.contains(&x.node.arg) || x.node.arg == "self") }); if !have_unique_fuction_parameter_name { return Err("class method must have unique parameter names \ and names thould not be the same as the keywords" .into()); } + if name == "__init__" && !defined_paramter_name.contains("self") { + return Err("__init__ function must have a `self` parameter".into()); + } let mut result = Vec::new(); for x in &args.args { let name = x.node.arg.clone(); - let type_ann = { - let annotation_expr = x - .node - .annotation - .as_ref() - .ok_or_else(|| "type annotation needed".to_string())? - .as_ref(); - parse_ast_to_type_annotation_kinds( - class_resolver.as_ref(), - temp_def_list, - unifier, - primitives, - annotation_expr, - )? - }; - // find type vars within this method parameter type annotation - let type_vars_within = get_type_var_contained_in_type_annotation(&type_ann); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVarKind(ty) = type_var_within { - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); + if name != "self" { + let type_ann = { + let annotation_expr = x + .node + .annotation + .as_ref() + .ok_or_else(|| "type annotation needed".to_string())? + .as_ref(); + parse_ast_to_type_annotation_kinds( + class_resolver.as_ref(), + temp_def_list, + unifier, + primitives, + annotation_expr, + )? + }; + // find type vars within this method parameter type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&type_ann); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + if let TypeAnnotation::TypeVarKind(ty) = type_var_within { + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); + } + } else { + unreachable!("must be type var annotation"); } - } else { - unreachable!("must be type var annotation"); } + // finish handling type vars + let dummy_func_arg = FuncArg { + name, + ty: unifier.get_fresh_var().0, + // TODO: symbol default value? + default_value: None, + }; + // push the dummy type and the type annotation + // into the list for later unification + type_var_to_concrete_def.insert(dummy_func_arg.ty, type_ann.clone()); + result.push(dummy_func_arg) } - // finish handling type vars - let dummy_func_arg = FuncArg { - name, - ty: unifier.get_fresh_var().0, - // TODO: symbol default value? - default_value: None, - }; - // push the dummy type and the type annotation - // into the list for later unification - type_var_to_concrete_def.insert(dummy_func_arg.ty, type_ann.clone()); - result.push(dummy_func_arg) } result }; diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index e210d002..f8c6cbde 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -156,7 +156,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec![ indoc! {" class A(): - def __init__(): + def __init__(self): self.a: int32 = 3 def fun(b: B): pass @@ -165,12 +165,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s "}, indoc! {" class B(C): - def __init__(): + def __init__(self): pass "}, indoc! {" class C(A): - def __init__(): + def __init__(self): pass def fun(b: B): a = 1 @@ -273,16 +273,16 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec![ indoc! {" class Generic_A(Generic[V], B): - def __init__(): + def __init__(self): self.a: int64 = 123123123123 - def fun(a: int32) -> V: + def fun(self, a: int32) -> V: pass "}, indoc! {" class B: - def __init__(): + def __init__(self): self.aa: bool = False - def foo(b: T): + def foo(self, b: T): pass "} ], @@ -343,7 +343,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s "}, indoc! {" class A(Generic[T, V]): - def __init__(v: V): + def __init__(self, v: V): self.a: T = 1 self.b: V = v def fun(a: T) -> V: @@ -355,7 +355,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s "}, indoc! {" class B: - def __init__(): + def __init__(self): pass "} ],