diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index ede7f77d3..286e0eed4 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -136,6 +136,8 @@ impl TopLevelComposer { "none".into(), "None".into(), "self".into(), + "Kernel".into(), + "KernelImmutable".into() ]), defined_class_method_name: Default::default(), defined_class_name: Default::default(), @@ -825,9 +827,10 @@ impl TopLevelComposer { }; let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); - + + let mut defined_fields: HashSet = HashSet::new(); for b in class_body_ast { - if let ast::StmtKind::FunctionDef { args, returns, name, body, .. } = &b.node { + if let ast::StmtKind::FunctionDef { args, returns, name, .. } = &b.node { let (method_dummy_ty, method_id) = Self::get_class_method_def_info(class_methods_def, name)?; @@ -968,61 +971,54 @@ impl TopLevelComposer { // NOTE: unify now since function type is not in type annotation define // which is fine since type within method_type will be subst later unifier.unify(method_dummy_ty, method_type)?; + } else if let ast::StmtKind::AnnAssign { target, annotation, value: None, .. } = &b.node { + if let ast::ExprKind::Name { id: attr, .. } = &target.node { + if defined_fields.insert(attr.to_string()) { + let dummy_field_type = unifier.get_fresh_var().0; + class_fields_def.push((attr.to_string(), dummy_field_type)); - // class fields - if name == "__init__" { - for b in body { - let mut defined_fields: HashSet = HashSet::new(); - // TODO: check the type of value, field instantiation check? - if let ast::StmtKind::AnnAssign { annotation, target, value: _, .. } = - &b.node - { - if let ast::ExprKind::Attribute { value, attr, .. } = &target.node { - if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "self") - { - if defined_fields.insert(attr.to_string()) { - let dummy_field_type = unifier.get_fresh_var().0; - class_fields_def.push((attr.to_string(), dummy_field_type)); + // handle Kernel[T], KernelImmutable[T] + let annotation = { + match &annotation.as_ref().node { + ast::ExprKind::Subscript { value, slice, .. } if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == "Kernel" || id == "KernelImmutable") + } => slice, + _ => annotation + } + }; - let annotation = parse_ast_to_type_annotation_kinds( - class_resolver.as_ref(), - &temp_def_list, - unifier, - primitives, - annotation.as_ref(), - vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), - )?; - - // find type vars within this return type annotation - let type_vars_within = - get_type_var_contained_in_type_annotation(&annotation); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - if let TypeAnnotation::TypeVarKind(t) = type_var_within - { - if !class_type_vars_def.contains(&t) { - return Err("class fields can only use type \ - vars declared as class generic type vars" - .into()); - } - } else { - unreachable!("must be type var annotation"); - } - } - - // TODO: allow class have field which type refers to Self type? - type_var_to_concrete_def - .insert(dummy_field_type, annotation); - } else { - return Err("same class fields defined twice".into()); - } + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver.as_ref(), + &temp_def_list, + unifier, + primitives, + annotation.as_ref(), + vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), + )?; + // find type vars within this return type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&annotation); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + if let TypeAnnotation::TypeVarKind(t) = type_var_within + { + if !class_type_vars_def.contains(&t) { + return Err("class fields can only use type \ + vars declared as class generic type vars" + .into()); } + } else { + unreachable!("must be type var annotation"); } } + type_var_to_concrete_def + .insert(dummy_field_type, annotation); + } else { + return Err("same class fields defined twice".into()); } } } else { - continue; + return Err("unsupported statement type in class definition body".into()) } } Ok(()) diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 22bdb313b..973888855 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -156,8 +156,9 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec![ indoc! {" class A(): + a: int32 def __init__(self): - self.a: int32 = 3 + self.a = 3 def fun(self, b: B): pass def foo(self, a: T, b: V): @@ -273,15 +274,17 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec![ indoc! {" class Generic_A(Generic[V], B): + a: int64 def __init__(self): - self.a: int64 = 123123123123 + self.a = 123123123123 def fun(self, a: int32) -> V: pass "}, indoc! {" class B: + aa: bool def __init__(self): - self.aa: bool = False + self.aa = False def foo(self, b: T): pass "} @@ -343,9 +346,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s "}, indoc! {" class A(Generic[T, V]): + a: T + b: V def __init__(self, v: V): - self.a: T = 1 - self.b: V = v + self.a = 1 + self.b = v def fun(self, a: T) -> V: pass "}, @@ -418,9 +423,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec![ indoc! {" class A(Generic[T, V]): + a: A[float, bool] + b: B def __init__(self, a: A[float, bool], b: B): - self.a: A[float, bool] = a - self.b: B = b + self.a = a + self.b = b def fun(self, a: A[float, bool]) -> A[bool, int32]: pass "}, @@ -578,7 +585,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s vec!["a class def can only have at most one base class declaration and one generic declaration"]; "err multiple inheritance" )] -fn test_simple_class_analyze(source: Vec<&str>, res: Vec<&str>) { +fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let print = false; let mut composer = TopLevelComposer::new();