diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 37316f76..ede7f77d 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -165,6 +165,7 @@ impl TopLevelComposer { ast: ast::Stmt<()>, resolver: Option>>, ) -> Result<(String, DefinitionId), String> { + // FIXME: different module same name? let defined_class_name = &mut self.defined_class_name; let defined_class_method_name = &mut self.defined_class_method_name; let defined_function_name = &mut self.defined_function_name; @@ -423,13 +424,19 @@ impl TopLevelComposer { // skip 5 to skip analyzing the primitives for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(5) { let mut class_def = class_def.write(); - let (class_bases, class_ancestors, class_resolver) = { - if let TopLevelDef::Class { ancestors, resolver, .. } = class_def.deref_mut() { + let ( + class_def_id, + class_bases, + class_ancestors, + class_resolver, + class_type_vars + ) = { + if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = class_def.deref_mut() { if let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast { - (bases, ancestors, resolver) + (object_id, bases, ancestors, resolver, type_vars) } else { unreachable!("must be both class") } @@ -469,6 +476,7 @@ impl TopLevelComposer { unifier, &self.primitives_ty, b, + vec![(*class_def_id, class_type_vars.clone())].into_iter().collect() )?; if let TypeAnnotation::CustomClassKind { .. } = &base_ty { @@ -667,6 +675,8 @@ impl TopLevelComposer { unifier, primitives_store, annotation, + // NOTE: since only class need this, for function, it should be fine + HashMap::new(), )?; let type_vars_within = @@ -713,6 +723,8 @@ impl TopLevelComposer { unifier, primitives_store, return_annotation, + // NOTE: since only class need this, for function, it should be fine + HashMap::new(), )? }; @@ -774,7 +786,7 @@ impl TopLevelComposer { ) -> Result<(), String> { let mut class_def = class_def.write(); let ( - _class_id, + class_id, _class_name, _class_bases_ast, class_body_ast, @@ -809,7 +821,7 @@ impl TopLevelComposer { unreachable!("here must be class def ast"); } } else { - unreachable!("here must be class def ast"); + unreachable!("here must be toplevel class def"); }; let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); @@ -846,6 +858,9 @@ impl TopLevelComposer { if name == "__init__" && !defined_paramter_name.contains("self") { return Err("__init__ function must have a `self` parameter".into()); } + if !defined_paramter_name.contains("self") { + return Err("currently does not support static method".into()) + } let mut result = Vec::new(); for x in &args.args { @@ -864,6 +879,7 @@ impl TopLevelComposer { unifier, primitives, annotation_expr, + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), )? }; // find type vars within this method parameter type annotation @@ -885,7 +901,7 @@ impl TopLevelComposer { let dummy_func_arg = FuncArg { name, ty: unifier.get_fresh_var().0, - // TODO: symbol default value? + // TODO: default value? default_value: None, }; // push the dummy type and the type annotation @@ -906,6 +922,7 @@ impl TopLevelComposer { unifier, primitives, result, + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), )?; // find type vars within this return type annotation let type_vars_within = @@ -973,6 +990,7 @@ impl TopLevelComposer { unifier, primitives, annotation.as_ref(), + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), )?; // find type vars within this return type annotation diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index ed9b61ed..22bdb313 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -158,9 +158,9 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s class A(): def __init__(self): self.a: int32 = 3 - def fun(b: B): + def fun(self, b: B): pass - def foo(a: T, b: V): + def foo(self, a: T, b: V): pass "}, indoc! {" @@ -172,7 +172,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s class C(A): def __init__(self): pass - def fun(b: B): + def fun(self, b: B): a = 1 pass "}, @@ -346,7 +346,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s def __init__(self, v: V): self.a: T = 1 self.b: V = v - def fun(a: T) -> V: + def fun(self, a: T) -> V: pass "}, indoc! {" @@ -414,6 +414,131 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s ]; "list tuple generic" )] +#[test_case( + vec![ + indoc! {" + class A(Generic[T, V]): + def __init__(self, a: A[float, bool], b: B): + self.a: A[float, bool] = a + self.b: B = b + def fun(self, a: A[float, bool]) -> A[bool, int32]: + pass + "}, + indoc! {" + class B(A[int64, bool]): + def __init__(self): + pass + def foo(self, b: B) -> B: + pass + def bar(self, a: A[list[B], int32]) -> tuple[A[virtual[A[B, int32]], bool], B]: + pass + "} + ], + vec![ + indoc! {"5: Class { + name: \"A\", + def_id: DefinitionId(5), + ancestors: [CustomClassKind { id: DefinitionId(5), params: [TypeVarKind(UnificationKey(100)), TypeVarKind(UnificationKey(101))] }], + fields: [(\"a\", \"class5[2->class2, 3->class3]\"), (\"b\", \"class9\")], + methods: [(\"__init__\", \"fn[[a=class5[2->class2, 3->class3], b=class9], class4]\", DefinitionId(6)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7))], + type_vars: [UnificationKey(100), UnificationKey(101)] + }"}, + + indoc! {"6: Function { + name: \"A__init__\", + sig: \"fn[[a=class5[2->class2, 3->class3], b=class9], class4]\", + var_id: [2, 3] + }"}, + + indoc! {"7: Function { + name: \"Afun\", + sig: \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", + var_id: [2, 3] + }"}, + + indoc! {"8: Initializer { DefinitionId(5) }"}, + + indoc! {"9: Class { + name: \"B\", + def_id: DefinitionId(9), + ancestors: [CustomClassKind { id: DefinitionId(9), params: [] }, CustomClassKind { id: DefinitionId(5), params: [PrimitiveKind(UnificationKey(1)), PrimitiveKind(UnificationKey(3))] }], + fields: [(\"a\", \"class5[2->class2, 3->class3]\"), (\"b\", \"class9\")], + methods: [(\"__init__\", \"fn[[], class4]\", DefinitionId(10)), (\"fun\", \"fn[[a=class5[2->class2, 3->class3]], class5[2->class3, 3->class0]]\", DefinitionId(7)), (\"foo\", \"fn[[b=class9], class9]\", DefinitionId(11)), (\"bar\", \"fn[[a=class5[2->list[class9], 3->class0]], tuple[class5[2->virtual[class5[2->class9, 3->class0]], 3->class3], class9]]\", DefinitionId(12))], + type_vars: [] + }"}, + + indoc! {"10: Function { + name: \"B__init__\", + sig: \"fn[[], class4]\", + var_id: [] + }"}, + + indoc! {"11: Function { + name: \"Bfoo\", + sig: \"fn[[b=class9], class9]\", + var_id: [] + }"}, + + indoc! {"12: Function { + name: \"Bbar\", + sig: \"fn[[a=class5[2->list[class9], 3->class0]], tuple[class5[2->virtual[class5[2->class9, 3->class0]], 3->class3], class9]]\", + var_id: [] + }"}, + + indoc! {"13: Initializer { DefinitionId(9) }"}, + ]; + "self1" +)] +#[test_case( + vec![ + indoc! {" + class A(Generic[T]): + def __init__(self): + pass + def fun(self, a: A[T]) -> A[T]: + pass + "} + ], + vec!["application of type vars to generic class is not currently supported"]; + "err no type var in generic app" +)] +#[test_case( + vec![ + indoc! {" + class A(B): + def __init__(self): + pass + "}, + indoc! {" + class B(A): + def __init__(self): + pass + "} + ], + vec!["cyclic inheritance detected"]; + "cyclic1" +)] +#[test_case( + vec![ + indoc! {" + class A(B[bool, int64]): + def __init__(self): + pass + "}, + indoc! {" + class B(Generic[V, T], C[int32]): + def __init__(self): + pass + "}, + indoc! {" + class C(Generic[T], A): + def __init__(self): + pass + "}, + ], + vec!["cyclic inheritance detected"]; + "cyclic2" +)] #[test_case( vec![indoc! {" class A: @@ -535,5 +660,4 @@ fn test_simple_class_analyze(source: Vec<&str>, res: Vec<&str>) { } } } - } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 69abf29f..54859ca6 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -26,6 +26,8 @@ pub fn parse_ast_to_type_annotation_kinds( unifier: &mut Unifier, primitives: &PrimitiveStore, expr: &ast::Expr, + // the key stores the type_var of this topleveldef::class, we only need this field here + mut locked: HashMap> ) -> Result { match &expr.node { ast::ExprKind::Name { id, .. } => match id.as_str() { @@ -36,19 +38,26 @@ pub fn parse_ast_to_type_annotation_kinds( "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), x => { if let Some(obj_id) = resolver.get_identifier_def(x) { - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { type_vars, .. } = &*def { - // also check param number here - if !type_vars.is_empty() { - return Err(format!( - "expect {} type variable parameter but got 0", - type_vars.len() - )); + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() + } else { + return Err("function cannot be used as a type".into()); + } + } else { + locked.get(&obj_id).unwrap().clone() } - Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] }) - } else { - Err("function cannot be used as a type".into()) + }; + // check param number here + if !type_vars.is_empty() { + return Err(format!( + "expect {} type variable parameter but got 0", + type_vars.len() + )); } + Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: vec![] }) } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { Ok(TypeAnnotation::TypeVarKind(ty)) @@ -71,6 +80,7 @@ pub fn parse_ast_to_type_annotation_kinds( unifier, primitives, slice.as_ref(), + locked )?; if !matches!(def, TypeAnnotation::CustomClassKind { .. }) { unreachable!("must be concretized custom class kind in the virtual") @@ -88,6 +98,7 @@ pub fn parse_ast_to_type_annotation_kinds( unifier, primitives, slice.as_ref(), + locked )?; Ok(TypeAnnotation::ListKind(def_ann.into())) } @@ -106,6 +117,7 @@ pub fn parse_ast_to_type_annotation_kinds( unifier, primitives, e, + locked.clone() ) }) .collect::, _>>()?; @@ -124,53 +136,63 @@ pub fn parse_ast_to_type_annotation_kinds( let obj_id = resolver .get_identifier_def(id) .ok_or_else(|| "unknown class name".to_string())?; - let def = top_level_defs[obj_id.0].read(); - if let TopLevelDef::Class { type_vars, .. } = &*def { - // we do not check whether the application of type variables are compatible here - let param_type_infos = { - let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - elts.iter().collect_vec() + let type_vars = { + let def_read = top_level_defs[obj_id.0].try_read(); + if let Some(def_read) = def_read { + if let TopLevelDef::Class { type_vars, .. } = &*def_read { + type_vars.clone() } else { - vec![slice.as_ref()] - }; - if type_vars.len() != params_ast.len() { - return Err(format!( - "expect {} type parameters but got {}", - type_vars.len(), - params_ast.len() - )); - } - let result = params_ast - .into_iter() - .map(|x| { - parse_ast_to_type_annotation_kinds( - resolver, - top_level_defs, - unifier, - primitives, - x, - ) - }) - .collect::, _>>()?; - - // make sure the result do not contain any type vars - let no_type_var = result - .iter() - .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); - if no_type_var { - result - } else { - return Err("application of type vars to generic class \ - is not currently supported" - .into()); + unreachable!("must be class here") } + } else { + locked.get(&obj_id).unwrap().clone() + } + }; + // we do not check whether the application of type variables are compatible here + let param_type_infos = { + let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter().collect_vec() + } else { + vec![slice.as_ref()] }; + if type_vars.len() != params_ast.len() { + return Err(format!( + "expect {} type parameters but got {}", + type_vars.len(), + params_ast.len() + )); + } + let result = params_ast + .into_iter() + .map(|x| { + parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + x, + { + locked.insert(obj_id, type_vars.clone()); + locked.clone() + } + ) + }) + .collect::, _>>()?; - // allow type var in class generic application list - Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: param_type_infos }) - } else { - Err("function cannot be used as a type".into()) - } + // make sure the result do not contain any type vars + let no_type_var = result + .iter() + .all(|x| get_type_var_contained_in_type_annotation(x).is_empty()); + if no_type_var { + result + } else { + return Err("application of type vars to generic class \ + is not currently supported" + .into()); + } + }; + + Ok(TypeAnnotation::CustomClassKind { id: obj_id, params: param_type_infos }) } else { Err("unsupported expression type for class name".into()) } @@ -180,6 +202,9 @@ pub fn parse_ast_to_type_annotation_kinds( } } +// no need to have the `locked` parameter, unlike the `parse_ast_to_type_annotation_kinds`, since +// when calling this function, there should be no topleveldef::class being write, and this function +// also only read the toplevedefs pub fn get_type_from_type_annotation_kinds( top_level_defs: &[Arc>], unifier: &mut Unifier, @@ -187,9 +212,10 @@ pub fn get_type_from_type_annotation_kinds( ann: &TypeAnnotation, ) -> Result { match ann { - TypeAnnotation::CustomClassKind { id, params } => { - let class_def = top_level_defs[id.0].read(); - if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*class_def { + TypeAnnotation::CustomClassKind { id: obj_id, params } => { + let def_read = top_level_defs[obj_id.0].read(); + let class_def: &TopLevelDef = def_read.deref(); + if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { if type_vars.len() != params.len() { Err(format!( "unexpected number of type parameters: expected {} but got {}", @@ -210,7 +236,8 @@ pub fn get_type_from_type_annotation_kinds( .collect::, _>>()?; let subst = { - // NOTE: check for compatible range here + // check for compatible range + // TODO: if allow type var to be applied, need more check let mut result: HashMap = HashMap::new(); for (tvar, p) in type_vars.iter().zip(param_ty) { if let TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic } = @@ -265,7 +292,7 @@ pub fn get_type_from_type_annotation_kinds( // ); Ok(unifier.add_ty(TypeEnum::TObj { - obj_id: *id, + obj_id: *obj_id, fields: RefCell::new(tobj_fields), params: subst.into(), }))