diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 62a5447..0fc965d 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,5 +1,6 @@ use std::rc::Rc; +use indexmap::IndexMap; use nac3parser::ast::{fold::Fold, ExprKind}; use super::*; @@ -439,9 +440,9 @@ impl TopLevelComposer { } } + /// Analyze the AST and modify the corresponding `TopLevelDef` pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet> { - self.analyze_top_level_class_type_var()?; - self.analyze_top_level_class_bases()?; + self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; if inference { @@ -451,442 +452,184 @@ impl TopLevelComposer { Ok(()) } - /// step 1, analyze the type vars associated with top level class - fn analyze_top_level_class_type_var(&mut self) -> Result<(), HashSet> { + /// step 1, analyze the top level class definitions + /// + /// Checks for class type variables and ancestors adding them to the `TopLevelDef` list + fn analyze_top_level_class_definition(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; - let temp_def_list = self.extract_def_list(); let unifier = self.unifier.borrow_mut(); let primitives_store = &self.primitives_ty; + let mut errors = HashSet::new(); - let mut analyze = |class_def: &Arc>, class_ast: &Option| { - // only deal with class def here - let mut class_def = class_def.write(); - let (class_bases_ast, class_def_type_vars, class_resolver) = { - if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { - let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = - class_ast + // Initially only copy the definitions of buitin classes and functions + // class definitions are added in the same order as they appear in the program + let mut temp_def_list: Vec>> = + def_list.iter().take(self.builtin_num).map(|f| f.0.clone()).collect_vec(); + + // Check for class generic variables and ancestors + for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { + if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) { + // Add class type variables and direct parents to the `TopLevelDef` + if let Err(e) = Self::analyze_class_bases( + class_def, + class_ast, + &temp_def_list, + unifier, + primitives_store, + ) { + errors.extend(e); + } + + // Add class ancestors + Self::analyze_class_ancestors(class_def, &temp_def_list); + + // special case classes that inherit from Exception + let TopLevelDef::Class { ancestors: class_ancestors, .. } = &*class_def.read() + else { + unreachable!() + }; + + if class_ancestors + .iter() + .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) + { + // if inherited from Exception, the body should be a pass + let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { unreachable!() }; - - (bases, type_vars, resolver) - } else { - return Ok(()); - } - }; - let class_resolver = class_resolver.as_ref().unwrap(); - let class_resolver = &**class_resolver; - - let mut is_generic = false; - for b in class_bases_ast { - match &b.node { - // analyze typevars bounded to the class, - // only support things like `class A(Generic[T, V])`, - // things like `class A(Generic[T, V, ImportedModule.T])` is not supported - // i.e. only simple names are allowed in the subscript - // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params - ast::ExprKind::Subscript { value, slice, .. } - if { - matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &"Generic".into() - ) - } => - { - if is_generic { - return Err(HashSet::from([format!( - "only single Generic[...] is allowed (at {})", - b.location - )])); + for stmt in body { + if matches!( + stmt.node, + ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } + ) { + errors.extend(Err(HashSet::from(["Classes inherited from exception should have no custom fields/methods"]))); } - is_generic = true; - - let type_var_list: Vec<&ast::Expr<()>>; - // if `class A(Generic[T, V, G])` - if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - type_var_list = elts.iter().collect_vec(); - // `class A(Generic[T])` - } else { - type_var_list = vec![&**slice]; - } - - // parse the type vars - let type_vars = type_var_list - .into_iter() - .map(|e| { - class_resolver.parse_type_annotation( - &temp_def_list, - unifier, - primitives_store, - e, - ) - }) - .collect::, _>>()?; - - // check if all are unique type vars - let all_unique_type_var = { - let mut occurred_type_var_id: HashSet = HashSet::new(); - type_vars.iter().all(|x| { - let ty = unifier.get_ty(*x); - if let TypeEnum::TVar { id, .. } = ty.as_ref() { - occurred_type_var_id.insert(*id) - } else { - false - } - }) - }; - if !all_unique_type_var { - return Err(HashSet::from([format!( - "duplicate type variable occurs (at {})", - slice.location - )])); - } - - // add to TopLevelDef - class_def_type_vars.extend(type_vars); } - - // if others, do nothing in this function - _ => continue, } } - Ok(()) - }; - - let mut errors = HashSet::new(); - for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - if let Err(e) = analyze(class_def, class_ast) { - errors.extend(e); - } + temp_def_list.push(class_def.clone()); } + + // deal with ancestors of Exception object + let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *def_list[7].0.write() + else { + unreachable!() + }; + assert_eq!(*name, "Exception".into()); + ancestors.push(make_self_type_annotation(&[], *object_id)); + if !errors.is_empty() { return Err(errors); } Ok(()) } - /// step 2, base classes. - /// now that the type vars of all classes are done, handle base classes and - /// put Self class into the ancestors list. We only allow single inheritance - fn analyze_top_level_class_bases(&mut self) -> Result<(), HashSet> { + /// step 2, class fields and methods + fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), HashSet> { + // Allow resolving definition IDs in error messages if self.unifier.top_level.is_none() { let ctx = Arc::new(self.make_top_level_context()); self.unifier.top_level = Some(ctx); } + let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); let unifier = self.unifier.borrow_mut(); - let primitive_types = self.primitives_ty; - - let mut get_direct_parents = - |class_def: &Arc>, class_ast: &Option| { - let mut class_def = class_def.write(); - let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { - if let TopLevelDef::Class { - ancestors, resolver, object_id, type_vars, .. - } = &mut *class_def - { - let Some(ast::Located { - node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast - else { - unreachable!() - }; - - (object_id, bases, ancestors, resolver, type_vars) - } else { - return Ok(()); - } - }; - let class_resolver = class_resolver.as_ref().unwrap(); - let class_resolver = &**class_resolver; - - let mut has_base = false; - for b in class_bases { - // type vars have already been handled, so skip on `Generic[...]` - if matches!( - &b.node, - ast::ExprKind::Subscript { value, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &"Generic".into() - ) - ) { - continue; - } - - if has_base { - return Err(HashSet::from([format!( - "a class definition can only have at most one base class \ - declaration and one generic declaration (at {})", - b.location - )])); - } - has_base = true; - - // the function parse_ast_to make sure that no type var occurred in - // bast_ty if it is a CustomClassKind - let base_ty = parse_ast_to_type_annotation_kinds( - class_resolver, - &temp_def_list, - unifier, - &primitive_types, - b, - vec![(*class_def_id, class_type_vars.clone())] - .into_iter() - .collect::>(), - )?; - - if let TypeAnnotation::CustomClass { .. } = &base_ty { - class_ancestors.push(base_ty); - } else { - return Err(HashSet::from([format!( - "class base declaration can only be custom class (at {})", - b.location, - )])); - } - } - Ok(()) - }; - - // first, only push direct parent into the list - let mut errors = HashSet::new(); - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - if let Err(e) = get_direct_parents(class_def, class_ast) { - errors.extend(e); - } - } - if !errors.is_empty() { - return Err(errors); - } - - // second, get all ancestors - let mut ancestors_store: HashMap> = HashMap::default(); - let mut get_all_ancestors = - |class_def: &Arc>| -> Result<(), HashSet> { - let class_def = class_def.read(); - let (class_ancestors, class_id) = { - if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { - (ancestors, *object_id) - } else { - return Ok(()); - } - }; - ancestors_store.insert( - class_id, - // if class has direct parents, get all ancestors of its parents. Else just empty - if class_ancestors.is_empty() { - vec![] - } else { - Self::get_all_ancestors_helper( - &class_ancestors[0], - temp_def_list.as_slice(), - )? - }, - ); - Ok(()) - }; - for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { - if ast.is_none() { - continue; - } - if let Err(e) = get_all_ancestors(class_def) { - errors.extend(e); - } - } - if !errors.is_empty() { - return Err(errors); - } - - // insert the ancestors to the def list - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - let mut class_def = class_def.write(); - let (class_ancestors, class_id, class_type_vars) = { - if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def - { - (ancestors, *object_id, type_vars) - } else { - continue; - } - }; - - let ans = ancestors_store.get_mut(&class_id).unwrap(); - class_ancestors.append(ans); - - // insert self type annotation to the front of the vector to maintain the order - class_ancestors - .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); - - // special case classes that inherit from Exception - if class_ancestors - .iter() - .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) - { - // if inherited from Exception, the body should be a pass - let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { - unreachable!() - }; - - for stmt in body { - if matches!( - stmt.node, - ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } - ) { - return Err(HashSet::from([ - "Classes inherited from exception should have no custom fields/methods" - .into(), - ])); - } - } - } - } - - // deal with ancestor of Exception object - let TopLevelDef::Class { name, ancestors, object_id, .. } = - &mut *self.definition_ast_list[7].0.write() - else { - unreachable!() - }; - - assert_eq!(*name, "Exception".into()); - ancestors.push(make_self_type_annotation(&[], *object_id)); - - Ok(()) - } - - /// step 3, class fields and methods - fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), HashSet> { - let temp_def_list = self.extract_def_list(); - let primitives = &self.primitives_ty; - let def_ast_list = &self.definition_ast_list; - let unifier = self.unifier.borrow_mut(); + let primitives_store = &self.primitives_ty; + let mut errors: HashSet = HashSet::new(); let mut type_var_to_concrete_def: HashMap = HashMap::new(); - let mut errors = HashSet::new(); - for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { + for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { + if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) { if let Err(e) = Self::analyze_single_class_methods_fields( class_def, &class_ast.as_ref().unwrap().node, &temp_def_list, unifier, - primitives, + primitives_store, &mut type_var_to_concrete_def, (&self.keyword_list, &self.core_config), ) { errors.extend(e); } - } - } - if !errors.is_empty() { - return Err(errors); - } - // handle the inherited methods and fields - // Note: we cannot defer error handling til the end of the loop, because there is loop - // carried dependency, ignoring the error (temporarily) will cause all assumptions to break - // and produce weird error messages - let mut current_ancestor_depth: usize = 2; - loop { - let mut finished = true; - - for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) { - if class_ast.is_none() { - continue; + // The errors need to be reported before copying methods from parent to child classes + if !errors.is_empty() { + return Err(errors); } - let mut class_def = class_def.write(); - if let TopLevelDef::Class { ancestors, .. } = &*class_def { - // if the length of the ancestor is equal to the current depth - // it means that all the ancestors of the class is handled - if ancestors.len() == current_ancestor_depth { - finished = false; - Self::analyze_single_class_ancestors( + + // The lock on `class_def` must be released once the ancestors are updated + { + let mut class_def = class_def.write(); + let TopLevelDef::Class { ancestors, .. } = &*class_def else { unreachable!() }; + // Methods/fields needs to be processed only if class inherits from another class + if ancestors.len() > 1 { + if let Err(e) = Self::analyze_single_class_ancestors( &mut class_def, &temp_def_list, unifier, - primitives, + primitives_store, &mut type_var_to_concrete_def, - )?; + ) { + errors.extend(e); + }; + } + } + + let mut subst_list = Some(Vec::new()); + // unification of previously assigned typevar + let mut unification_helper = |ty, def| -> Result<(), HashSet> { + let target_ty = get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + primitives_store, + &def, + &mut subst_list, + )?; + unifier + .unify(ty, target_ty) + .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; + Ok(()) + }; + for (ty, def) in &type_var_to_concrete_def { + if let Err(e) = unification_helper(*ty, def.clone()) { + errors.extend(e); + } + } + for ty in subst_list.unwrap() { + let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else { + unreachable!() + }; + + let mut new_fields = HashMap::new(); + let mut need_subst = false; + for (name, (ty, mutable)) in fields { + let substituted = unifier.subst(*ty, params); + need_subst |= substituted.is_some(); + new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); + } + if need_subst { + let new_ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + params: params.clone(), + fields: new_fields, + }); + if let Err(e) = unifier.unify(ty, new_ty) { + errors.insert(e.to_display(unifier).to_string()); + } } } } - - if finished { - break; - } - - current_ancestor_depth += 1; - if current_ancestor_depth > def_ast_list.len() + 1 { - unreachable!("cannot be longer than the whole top level def list") - } } - let mut subst_list = Some(Vec::new()); - // unification of previously assigned typevar - let mut unification_helper = |ty, def| -> Result<(), HashSet> { - let target_ty = get_type_from_type_annotation_kinds( - &temp_def_list, - unifier, - primitives, - &def, - &mut subst_list, - )?; - unifier - .unify(ty, target_ty) - .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; - Ok(()) - }; - for (ty, def) in type_var_to_concrete_def { - if let Err(e) = unification_helper(ty, def) { - errors.extend(e); - } - } - for ty in subst_list.unwrap() { - let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else { - unreachable!() - }; - - let mut new_fields = HashMap::new(); - let mut need_subst = false; - for (name, (ty, mutable)) in fields { - let substituted = unifier.subst(*ty, params); - need_subst |= substituted.is_some(); - new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); - } - if need_subst { - let new_ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - params: params.clone(), - fields: new_fields, - }); - if let Err(e) = unifier.unify(ty, new_ty) { - errors.insert(e.to_display(unifier).to_string()); - } - } - } - if !errors.is_empty() { - return Err(errors); - } - - for (def, _) in def_ast_list.iter().skip(self.builtin_num) { + for (def, _) in def_list.iter().skip(self.builtin_num) { match &*def.read() { TopLevelDef::Class { resolver: Some(resolver), .. } | TopLevelDef::Function { resolver: Some(resolver), .. } => { if let Err(e) = - resolver.handle_deferred_eval(unifier, &temp_def_list, primitives) + resolver.handle_deferred_eval(unifier, &temp_def_list, primitives_store) { errors.insert(e); } @@ -895,10 +638,13 @@ impl TopLevelComposer { } } + if !errors.is_empty() { + return Err(errors); + } Ok(()) } - /// step 4, after class methods are done, top level functions have nothing unknown + /// step 3, after class methods are done, top level functions have nothing unknown fn analyze_top_level_function(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let keyword_list = &self.keyword_list; @@ -1253,126 +999,83 @@ impl TopLevelComposer { let mut method_var_map = VarMap::new(); let arg_types: Vec = { - // check method parameters cannot have same name + // Function arguments must have: + // 1) `self` as first argument (we currently do not support staticmethods) + // 2) unique names + // 3) names different than keywords + match args.args.first() { + Some(id) if id.node.arg == "self".into() => {}, + _ => return Err(HashSet::from([format!( + "{name} method must have a `self` parameter (at {})", b.location + )])), + } let mut defined_parameter_name: HashSet<_> = HashSet::new(); - let zelf: StrRef = "self".into(); - for x in &args.args { - if !defined_parameter_name.insert(x.node.arg) - || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) - { - return Err(HashSet::from([ - format!("top level function must have unique parameter names \ - and names should not be the same as the keywords (at {})", - x.location), - ])) + for arg in args.args.iter().skip(1) { + if !defined_parameter_name.insert(arg.node.arg) { + return Err(HashSet::from([format!("class method must have a unique parameter names (at {})", b.location)])); + } + if keyword_list.contains(&arg.node.arg) { + return Err(HashSet::from([format!("parameter names should not be the same as the keywords (at {})", b.location)])); } } - if name == &"__init__".into() && !defined_parameter_name.contains(&zelf) { - return Err(HashSet::from([ - format!("__init__ method must have a `self` parameter (at {})", b.location), - ])) + // `self` must not be provided type annotation or default value + if args.args.len() == args.defaults.len() { + return Err(HashSet::from([format!("`self` cannot have a default value (at {})", b.location)])); } - if !defined_parameter_name.contains(&zelf) { - return Err(HashSet::from([ - format!("class method must have a `self` parameter (at {})", b.location), - ])) + if args.args[0].node.annotation.is_some() { + return Err(HashSet::from([format!("`self` cannot have a type annotation (at {})", b.location)])); } - let mut result = Vec::new(); - - let arg_with_default: Vec<( - &ast::Located>, - Option<&ast::Expr>, - )> = args - .args - .iter() - .rev() - .zip( - args.defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)), - ) - .collect_vec(); - - for (x, default) in arg_with_default.into_iter().rev() { - let name = x.node.arg; - if name != zelf { - let type_ann = { - let annotation_expr = x - .node - .annotation - .as_ref() - .ok_or_else(|| HashSet::from([ - format!( - "type annotation needed for `{}` at {}", - x.node.arg, x.location - ), - ]))? - .as_ref(); - parse_ast_to_type_annotation_kinds( - class_resolver, - temp_def_list, - unifier, - primitives, - annotation_expr, - vec![(class_id, class_type_vars_def.clone())] - .into_iter() - .collect::>(), - )? + let no_defaults = args.args.len() - args.defaults.len() - 1; + for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) { + let type_ann = { + let Some(annotation_expr) = x.node.annotation.as_ref() else {return Err(HashSet::from([format!("type annotation needed for `{}` (at {})", x.node.arg, x.location)]));}; + parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + annotation_expr, + vec![(class_id, class_type_vars_def.clone())] + .into_iter() + .collect::>(), + )? + }; + // find type vars within this method parameter type annotation + let type_vars_within = get_type_var_contained_in_type_annotation(&type_ann); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + let TypeAnnotation::TypeVar(ty) = type_var_within else { + unreachable!("must be type var annotation") }; - // find type vars within this method parameter type annotation - let type_vars_within = - get_type_var_contained_in_type_annotation(&type_ann); - // handle the class type var and the method type var - for type_var_within in type_vars_within { - let TypeAnnotation::TypeVar(ty) = type_var_within else { - unreachable!("must be type var annotation") - }; - let id = Self::get_var_id(ty, unifier)?; - if let Some(prev_ty) = method_var_map.insert(id, ty) { - // if already in the list, make sure they are the same? - assert_eq!(prev_ty, ty); - } + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); } - // finish handling type vars - let dummy_func_arg = FuncArg { - name, - ty: unifier.get_dummy_var().ty, - default_value: match default { - None => None, - Some(default) => { - if name == "self".into() { - return Err(HashSet::from([ - format!("`self` parameter cannot take default value (at {})", x.location), - ])); - } - Some({ - let v = Self::parse_parameter_default_value( - default, - class_resolver, - )?; - Self::check_default_param_type( - &v, &type_ann, primitives, unifier, - ) - .map_err(|err| HashSet::from([ - format!("{} (at {})", err, x.location), - ]))?; - v - }) - } - }, - is_vararg: false, - }; - // push the dummy type and the type annotation - // into the list for later unification - type_var_to_concrete_def - .insert(dummy_func_arg.ty, type_ann.clone()); - result.push(dummy_func_arg); } + // finish handling type vars + let dummy_func_arg = FuncArg { + name: x.node.arg, + ty: unifier.get_dummy_var().ty, + default_value: if idx < no_defaults { None } else { + let default_idx = idx - no_defaults; + + Some({ + let v = Self::parse_parameter_default_value(&args.defaults[default_idx], class_resolver)?; + Self::check_default_param_type(&v, &type_ann, primitives, unifier).map_err(|err| HashSet::from([format!("{} (at {})", err, x.location)]))?; + v + }) + }, + is_vararg: false, + }; + // push the dummy type and the type annotation + // into the list for later unification + type_var_to_concrete_def + .insert(dummy_func_arg.ty, type_ann.clone()); + result.push(dummy_func_arg); } result }; @@ -1494,12 +1197,12 @@ impl TopLevelComposer { match v { ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} _ => { - return Err(HashSet::from([ - format!( - "unsupported statement in class definition body (at {})", - b.location - ), - ])) + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) } } class_attributes_def.push((*attr, dummy_field_type, v.clone())); @@ -1535,7 +1238,7 @@ impl TopLevelComposer { unreachable!("must be type var annotation") }; - if !class_type_vars_def.contains(&t) { + if !class_type_vars_def.contains(&t){ return Err(HashSet::from([ format!( "class fields can only use type \ @@ -1569,7 +1272,7 @@ impl TopLevelComposer { _ => { return Err(HashSet::from([ format!( - "unsupported statement in class definition body (at {})", + "unsupported statement type in class definition body (at {})", b.location ), ])) @@ -1615,7 +1318,6 @@ impl TopLevelComposer { let TypeAnnotation::CustomClass { id, params: _ } = base else { unreachable!("must be class type annotation") }; - let base = temp_def_list.get(id.0).unwrap(); let base = base.read(); let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else { @@ -1624,93 +1326,68 @@ impl TopLevelComposer { // handle methods override // since we need to maintain the order, create a new list - let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new(); - let mut is_override: HashSet = HashSet::new(); - for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { - // find if there is a method with same name in the child class - let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); - for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { - if class_method_name == anc_method_name { - // ignore and handle self - // if is __init__ method, no need to check return type - let ok = class_method_name == &"__init__".into() - || Self::check_overload_function_type( - *class_method_ty, - *anc_method_ty, - unifier, - type_var_to_concrete_def, - ); - if !ok { - return Err(HashSet::from([format!( - "method {class_method_name} has same name as ancestors' method, but incompatible type"), - ])); - } - // mark it as added - is_override.insert(*class_method_name); - to_be_added = (*class_method_name, *class_method_ty, *class_method_defid); - break; + let mut new_child_methods: IndexMap = + methods.iter().map(|m| (m.0, (m.1, m.2))).collect(); + + // let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = methods.clone(); + for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { + if let Some((ty, _)) = new_child_methods + .insert(*class_method_name, (*class_method_ty, *class_method_defid)) + { + let ok = class_method_name == &"__init__".into() + || Self::check_overload_function_type( + *class_method_ty, + ty, + unifier, + type_var_to_concrete_def, + ); + if !ok { + return Err(HashSet::from([format!( + "method {class_method_name} has same name as ancestors' method, but incompatible type"), + ])); } } - new_child_methods.push(to_be_added); } - // add those that are not overriding method to the new_child_methods - for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { - if !is_override.contains(class_method_name) { - new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid)); - } - } - // use the new_child_methods to replace all the elements in `class_methods_def` class_methods_def.clear(); - class_methods_def.extend(new_child_methods); + class_methods_def + .extend(new_child_methods.iter().map(|f| (*f.0, f.1 .0, f.1 .1)).collect_vec()); // handle class fields - let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new(); - // let mut is_override: HashSet<_> = HashSet::new(); - for (anc_field_name, anc_field_ty, mutable) in fields { - let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); - // find if there is a fields with the same name in the child class - for (class_field_name, ..) in &*class_fields_def { - if class_field_name == anc_field_name - || attributes.iter().any(|f| f.0 == *class_field_name) - { - return Err(HashSet::from([format!( - "field `{class_field_name}` has already declared in the ancestor classes" - )])); - } + let mut new_child_fields: IndexMap = + fields.iter().map(|f| (f.0, (f.1, f.2))).collect(); + let mut new_child_attributes: IndexMap = + attributes.iter().map(|f| (f.0, (f.1, f.2.clone()))).collect(); + // Overriding class fields and attributes is currently not supported + for (name, ty, mutable) in &*class_fields_def { + if new_child_fields.insert(*name, (*ty, *mutable)).is_some() + || new_child_attributes.contains_key(name) + { + return Err(HashSet::from([format!( + "field `{name}` has already declared in the ancestor classes" + )])); + } + } + for (name, ty, val) in &*class_attribute_def { + if new_child_attributes.insert(*name, (*ty, val.clone())).is_some() + || new_child_fields.contains_key(name) + { + return Err(HashSet::from([format!( + "attribute `{name}` has already declared in the ancestor classes" + )])); } - new_child_fields.push(to_be_added); } - // handle class attributes - let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new(); - for (anc_attr_name, anc_attr_ty, attr_value) in attributes { - let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone()); - // find if there is a attribute with the same name in the child class - for (class_attr_name, ..) in &*class_attribute_def { - if class_attr_name == anc_attr_name - || fields.iter().any(|f| f.0 == *class_attr_name) - { - return Err(HashSet::from([format!( - "attribute `{class_attr_name}` has already declared in the ancestor classes" - )])); - } - } - new_child_attributes.push(to_be_added); - } - - for (class_field_name, class_field_ty, mutable) in &*class_fields_def { - if !is_override.contains(class_field_name) { - new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); - } - } class_fields_def.clear(); - class_fields_def.extend(new_child_fields); + class_fields_def + .extend(new_child_fields.iter().map(|f| (*f.0, f.1 .0, f.1 .1)).collect_vec()); class_attribute_def.clear(); - class_attribute_def.extend(new_child_attributes); + class_attribute_def.extend( + new_child_attributes.iter().map(|f| (*f.0, f.1 .0, f.1 .1.clone())).collect_vec(), + ); Ok(()) } - /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -2229,7 +1906,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 6. Analyze and populate the types of global variables. + /// Step 5. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index d674c51..b54ad83 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -626,64 +626,6 @@ impl TopLevelComposer { Err(HashSet::from([format!("no method {method_name} in the current class")])) } - /// get all base class def id of a class, excluding itself. \ - /// this function should called only after the direct parent is set - /// and before all the ancestors are set - /// and when we allow single inheritance \ - /// the order of the returned list is from the child to the deepest ancestor - pub fn get_all_ancestors_helper( - child: &TypeAnnotation, - temp_def_list: &[Arc>], - ) -> Result, HashSet> { - let mut result: Vec = Vec::new(); - let mut parent = Self::get_parent(child, temp_def_list); - while let Some(p) = parent { - parent = Self::get_parent(&p, temp_def_list); - let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p { - *id - } else { - unreachable!("must be class kind annotation") - }; - // check cycle - let no_cycle = result.iter().all(|x| { - let TypeAnnotation::CustomClass { id, .. } = x else { - unreachable!("must be class kind annotation") - }; - - id.0 != p_id.0 - }); - if no_cycle { - result.push(p); - } else { - return Err(HashSet::from(["cyclic inheritance detected".into()])); - } - } - Ok(result) - } - - /// should only be called when finding all ancestors, so panic when wrong - fn get_parent( - child: &TypeAnnotation, - temp_def_list: &[Arc>], - ) -> Option { - let child_id = if let TypeAnnotation::CustomClass { id, .. } = child { - *id - } else { - unreachable!("should be class type annotation") - }; - let child_def = temp_def_list.get(child_id.0).unwrap(); - let child_def = child_def.read(); - let TopLevelDef::Class { ancestors, .. } = &*child_def else { - unreachable!("child must be top level class def") - }; - - if ancestors.is_empty() { - None - } else { - Some(ancestors[0].clone()) - } - } - /// get the `var_id` of a given `TVar` type pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result> { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { @@ -993,6 +935,139 @@ impl TopLevelComposer { )) } } + + /// Parses the class type variables and direct parents + /// we only allow single inheritance + pub fn analyze_class_bases( + class_def: &Arc>, + class_ast: &Option, + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives_store: &PrimitiveStore, + ) -> Result<(), HashSet> { + let mut class_def = class_def.write(); + let (class_def_id, class_ancestors, class_bases_ast, class_type_vars, class_resolver) = { + let TopLevelDef::Class { object_id, ancestors, type_vars, resolver, .. } = + &mut *class_def + else { + unreachable!() + }; + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast + else { + unreachable!() + }; + (object_id, ancestors, bases, type_vars, resolver.as_ref().unwrap().as_ref()) + }; + + let mut is_generic = false; + let mut has_base = false; + // Check class bases for typevars + for b in class_bases_ast { + match &b.node { + // analyze typevars bounded to the class, + // only support things like `class A(Generic[T, V])`, + // things like `class A(Generic[T, V, ImportedModule.T])` is not supported + // i.e. only simple names are allowed in the subscript + // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params + ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Generic".into()) => + { + if is_generic { + return Err(HashSet::from([format!( + "only single Generic[...] is allowed (at {})", + b.location + )])); + } + is_generic = true; + + let type_var_list: Vec<&ast::Expr<()>>; + // if `class A(Generic[T, V, G])` + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + type_var_list = elts.iter().collect_vec(); + // `class A(Generic[T])` + } else { + type_var_list = vec![&**slice]; + } + + let type_vars = type_var_list + .into_iter() + .map(|e| { + class_resolver.parse_type_annotation( + temp_def_list, + unifier, + primitives_store, + e, + ) + }) + .collect::, _>>()?; + + class_type_vars.extend(type_vars); + } + ast::ExprKind::Name { .. } | ast::ExprKind::Subscript { .. } => { + if has_base { + return Err(HashSet::from([format!("a class definition can only have at most one base class declaration and one generic declaration (at {})", b.location )])); + } + has_base = true; + // the function parse_ast_to make sure that no type var occurred in + // bast_ty if it is a CustomClassKind + let base_ty = parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives_store, + b, + vec![(*class_def_id, class_type_vars.clone())] + .into_iter() + .collect::>(), + )?; + if let TypeAnnotation::CustomClass { .. } = &base_ty { + class_ancestors.push(base_ty); + } else { + return Err(HashSet::from([format!( + "class base declaration can only be custom class (at {})", + b.location + )])); + } + } + _ => { + return Err(HashSet::from([format!( + "unsupported statement in class defintion (at {})", + b.location + )])); + } + } + } + + Ok(()) + } + + /// gets all ancestors of a class + pub fn analyze_class_ancestors( + class_def: &Arc>, + temp_def_list: &[Arc>], + ) { + // Check if class has a direct parent + let mut class_def = class_def.write(); + let TopLevelDef::Class { ancestors, type_vars, object_id, .. } = &mut *class_def else { + unreachable!() + }; + let mut anc_set = HashMap::new(); + + if let Some(ancestor) = ancestors.first() { + let TypeAnnotation::CustomClass { id, .. } = ancestor else { unreachable!() }; + let TopLevelDef::Class { ancestors: parent_ancestors, .. } = + &*temp_def_list[id.0].read() + else { + unreachable!() + }; + for anc in parent_ancestors.iter().skip(1) { + let TypeAnnotation::CustomClass { id, .. } = anc else { unreachable!() }; + anc_set.insert(id, anc.clone()); + } + ancestors.extend(anc_set.into_values()); + } + // push `self` as first ancestor of class + ancestors.insert(0, make_self_type_annotation(type_vars.as_slice(), *object_id)); + } } pub fn parse_parameter_default_value( diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 53ff774..1b0c9b8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", - "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", + "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 502abbd..9a9c4dd 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -7,11 +7,11 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", - "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 077f6ab..5d18aba 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -230,11 +230,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { def foo(self, a: T, b: V): pass "}, - indoc! {" - class B(C): - def __init__(self): - pass - "}, indoc! {" class C(A): def __init__(self): @@ -243,6 +238,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { a = 1 pass "}, + indoc! {" + class B(C): + def __init__(self): + pass + "}, indoc! {" def foo(a: A): pass @@ -257,6 +257,14 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { )] #[test_case( &[ + indoc! {" + class B: + aa: bool + def __init__(self): + self.aa = False + def foo(self, b: T): + pass + "}, indoc! {" class Generic_A(Generic[V], B): a: int64 @@ -264,14 +272,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { self.a = 123123123123 def fun(self, a: int32) -> V: pass - "}, - indoc! {" - class B: - aa: bool - def __init__(self): - self.aa = False - def foo(self, b: T): - pass "} ], &[]; @@ -391,18 +391,18 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { pass "} ], - &["cyclic inheritance detected"]; + &["NameError: name 'B' is not defined (at unknown:1:9)"]; "cyclic1" )] #[test_case( &[ indoc! {" - class A(B[bool, int64]): - def __init__(self): - pass + class B(Generic[V, T], C[int32]): + def __init__(self): + pass "}, indoc! {" - class B(Generic[V, T], C[int32]): + class A(B[bool, int64]): def __init__(self): pass "}, @@ -412,7 +412,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { pass "}, ], - &["cyclic inheritance detected"]; + &["NameError: name 'C' is not defined (at unknown:1:25)"]; "cyclic2" )] #[test_case( @@ -436,11 +436,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { )] #[test_case( &[ - indoc! {" - class A(B, Generic[T], C): - def __init__(self): - pass - "}, indoc! {" class B: def __init__(self): @@ -450,6 +445,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { class C: def __init__(self): pass + "}, + indoc! {" + class A(B, Generic[T], C): + def __init__(self): + pass "} ], diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index f8b16f8..e3b2911 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -101,7 +101,13 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }) } else if let Ok(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); + let Some(top_level_def) = top_level_defs.get(obj_id.0) else { + return Err(HashSet::from([format!( + "NameError: name '{id}' is not defined (at {})", + expr.location + )])); + }; + let def_read = top_level_def.try_read(); if let Some(def_read) = def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read { type_vars.clone() @@ -156,12 +162,17 @@ pub fn parse_ast_to_type_annotation_kinds( } let obj_id = resolver.get_identifier_def(*id)?; let type_vars = { - let def_read = top_level_defs[obj_id.0].try_read(); + let Some(top_level_def) = top_level_defs.get(obj_id.0) else { + return Err(HashSet::from([format!( + "NameError: name '{id}' is not defined (at {})", + expr.location + )])); + }; + let def_read = top_level_def.try_read(); if let Some(def_read) = def_read { let TopLevelDef::Class { type_vars, .. } = &*def_read else { unreachable!("must be class here") }; - type_vars.clone() } else { locked.get(&obj_id).unwrap().clone()