diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index ede82f85a..cc2c3277c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,5 +1,4 @@ use super::*; -use crate::typecheck::typedef::TypeVarMeta; impl TopLevelComposer { pub fn make_primitives() -> (PrimitiveStore, Unifier) { @@ -93,14 +92,31 @@ impl TopLevelComposer { pub fn get_all_ancestors_helper( child: &TypeAnnotation, temp_def_list: &[Arc>], - ) -> Vec { + ) -> Result, String> { let mut result: Vec = Vec::new(); let mut parent = Self::get_parent(child, temp_def_list); while let Some(p) = parent { parent = Self::get_parent(&p, temp_def_list); - result.push(p); + let p_id = if let TypeAnnotation::CustomClassKind { id, .. } = &p { + *id + } else { + unreachable!("must be class kind annotation") + }; + // check cycle + let no_cycle = result.iter().all(|x| { + if let TypeAnnotation::CustomClassKind { id, .. } = x { + id.0 != p_id.0 + } else { + unreachable!("must be class kind annotation") + } + }); + if no_cycle { + result.push(p); + } else { + return Err("cyclic inheritance detected".into()); + } } - result + Ok(result) } /// should only be called when finding all ancestors, so panic when wrong @@ -126,51 +142,6 @@ impl TopLevelComposer { } } - pub fn check_overload_type_compatible(unifier: &mut Unifier, ty: Type, other: Type) -> bool { - let ty = unifier.get_ty(ty); - let ty = ty.as_ref(); - let other = unifier.get_ty(other); - let other = other.as_ref(); - - match (ty, other) { - (TypeEnum::TList { ty }, TypeEnum::TList { ty: other }) - | (TypeEnum::TVirtual { ty }, TypeEnum::TVirtual { ty: other }) => { - Self::check_overload_type_compatible(unifier, *ty, *other) - } - - (TypeEnum::TTuple { ty }, TypeEnum::TTuple { ty: other }) => ty - .iter() - .zip(other) - .all(|(ty, other)| Self::check_overload_type_compatible(unifier, *ty, *other)), - - ( - TypeEnum::TObj { obj_id, params, .. }, - TypeEnum::TObj { obj_id: other_obj_id, params: other_params, .. }, - ) => { - let params = &*params.borrow(); - let other_params = &*other_params.borrow(); - obj_id.0 == other_obj_id.0 - && (params.iter().all(|(var_id, ty)| { - if let Some(other_ty) = other_params.get(var_id) { - Self::check_overload_type_compatible(unifier, *ty, *other_ty) - } else { - false - } - })) - } - - ( - TypeEnum::TVar { id, meta: TypeVarMeta::Generic, .. }, - TypeEnum::TVar { id: other_id, meta: TypeVarMeta::Generic, .. }, - ) => { - // NOTE: directly compare var_id? - *id == *other_id - } - - _ => false, - } - } - /// get the var_id of a given TVar type pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { @@ -179,4 +150,63 @@ impl TopLevelComposer { Err("not type var".to_string()) } } + + pub fn check_overload_function_type( + this: Type, + other: Type, + unifier: &mut Unifier, + type_var_to_concrete_def: &HashMap, + ) -> bool { + let this = unifier.get_ty(this); + let this = this.as_ref(); + let other = unifier.get_ty(other); + let other = other.as_ref(); + if let (TypeEnum::TFunc(this_sig), TypeEnum::TFunc(other_sig)) = (this, other) { + let (this_sig, other_sig) = (&*this_sig.borrow(), &*other_sig.borrow()); + let ( + FunSignature { args: this_args, ret: this_ret, vars: _this_vars }, + FunSignature { args: other_args, ret: other_ret, vars: _other_vars }, + ) = (this_sig, other_sig); + // check args + let args_ok = this_args + .iter() + .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) + .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { + (name, type_var_to_concrete_def.get(ty).unwrap()) + })) + .all(|(this, other)| { + if this.0 == "self" && this.0 == other.0 { + true + } else { + this.0 == other.0 + && check_overload_type_annotation_compatible(this.1, other.1, unifier) + } + }); + + // check rets + let ret_ok = check_overload_type_annotation_compatible( + type_var_to_concrete_def.get(this_ret).unwrap(), + type_var_to_concrete_def.get(other_ret).unwrap(), + unifier, + ); + + // return + args_ok && ret_ok + } else { + unreachable!("this function must be called with function type") + } + } + + pub fn check_overload_field_type( + this: Type, + other: Type, + unifier: &mut Unifier, + type_var_to_concrete_def: &HashMap, + ) -> bool { + check_overload_type_annotation_compatible( + type_var_to_concrete_def.get(&this).unwrap(), + type_var_to_concrete_def.get(&other).unwrap(), + unifier, + ) + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c4017da4c..d31378d50 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -195,6 +195,7 @@ impl TopLevelComposer { // we do not push anything to the def list, so we keep track of the index // and then push in the correct order after the for loop let mut class_method_index_offset = 0; + let mut has_init = false; for b in body { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { if self.keyword_list.contains(name) { @@ -205,6 +206,9 @@ impl TopLevelComposer { if !defined_class_method_name.insert(global_class_method_name.clone()) { return Err("duplicate class method definition".into()); } + if method_name == "__init__" { + has_init = true; + } let method_def_id = self.definition_ast_list.len() + { // plus 1 here since we already have the class def class_method_index_offset += 1; @@ -230,6 +234,9 @@ impl TopLevelComposer { continue; } } + if !has_init { + return Err("class def must have __init__ method defined".into()); + } // move the ast to the entry of the class in the ast_list class_def_ast.1 = Some(ast); @@ -469,7 +476,7 @@ impl TopLevelComposer { if class_ancestors.is_empty() { vec![] } else { - Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice()) + Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())? }, ); } @@ -499,9 +506,9 @@ impl TopLevelComposer { /// step 3, class fields and methods fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { let temp_def_list = self.extract_def_list(); - let unifier = self.unifier.borrow_mut(); let primitives = &self.primitives_ty; let def_ast_list = &self.definition_ast_list; + let unifier = self.unifier.borrow_mut(); let mut type_var_to_concrete_def: HashMap = HashMap::new(); @@ -517,6 +524,40 @@ impl TopLevelComposer { )? } + // handle the inheritanced methods and fields + let mut current_ancestor_depth: usize = 2; + loop { + let mut finished = true; + + for (class_def, _) in def_ast_list { + let mut class_def = class_def.write(); + if let TopLevelDef::Class { ancestors, .. } = class_def.deref() { + // if the length of the ancestor is equal to the current depth + // it means that all the ancestors of the class is handled + if ancestors.len() == current_ancestor_depth { + finished = false; + Self::analyze_single_class_ancestors( + class_def.deref_mut(), + &temp_def_list, + unifier, + primitives, + &mut type_var_to_concrete_def, + )?; + } + } + } + + if finished { + break; + } else { + current_ancestor_depth += 1; + } + + if current_ancestor_depth > def_ast_list.len() + 1 { + unreachable!("cannot be longer than the whole top level def list") + } + } + // unification of previously assigned typevar for (ty, def) in type_var_to_concrete_def { let target_ty = @@ -524,16 +565,6 @@ impl TopLevelComposer { unifier.unify(ty, target_ty)?; } - // handle the inheritanced methods and fields - for (class_def, _) in def_ast_list { - Self::analyze_single_class_ancestors( - class_def.clone(), - &temp_def_list, - unifier, - primitives, - )?; - } - Ok(()) } @@ -596,7 +627,6 @@ impl TopLevelComposer { annotation, )?; - // if there are same type variables appears, we only need to copy them once let type_vars_within = get_type_var_contained_in_type_annotation(&type_annotation) .into_iter() @@ -679,6 +709,7 @@ impl TopLevelComposer { unreachable!("must be both function"); } } else { + // not top level function def, skip continue; } } @@ -942,16 +973,16 @@ impl TopLevelComposer { } fn analyze_single_class_ancestors( - class_def: Arc>, + class_def: &mut TopLevelDef, temp_def_list: &[Arc>], unifier: &mut Unifier, - primitives: &PrimitiveStore, + _primitives: &PrimitiveStore, + type_var_to_concrete_def: &mut HashMap, ) -> Result<(), String> { - let mut class_def = class_def.write(); let ( _class_id, class_ancestor_def, - _class_fields_def, + class_fields_def, class_methods_def, _class_type_vars_def, _class_resolver, @@ -963,99 +994,110 @@ impl TopLevelComposer { resolver, type_vars, .. - } = class_def.deref_mut() + } = class_def { (*object_id, ancestors, fields, methods, type_vars, resolver) } else { unreachable!("here must be class def ast"); }; - for (method_name, method_ty, ..) in class_methods_def { - if method_name == "__init__" { - continue; - } - // search the ancestors from the nearest to the deepest to find overload and check - 'search_for_overload: for anc in class_ancestor_def.iter().skip(1) { - if let TypeAnnotation::CustomClassKind { id, params } = anc { - let anc_class_def = temp_def_list.get(id.0).unwrap().read(); - let anc_class_def = anc_class_def.deref(); - - if let TopLevelDef::Class { methods, type_vars, .. } = anc_class_def { - for (anc_method_name, anc_method_ty, ..) in methods { - // if same name, then is overload, needs check - if anc_method_name == method_name { - let param_ty = params - .iter() - .map(|x| { - get_type_from_type_annotation_kinds( - temp_def_list, - unifier, - primitives, - x, - ) - }) - .collect::, _>>()?; - - let subst = type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = - unifier.get_ty(*x).as_ref() - { - *id - } else { - unreachable!() - } - }) - .zip(param_ty.into_iter()) - .collect::>(); - - let anc_method_ty = unifier.subst(*anc_method_ty, &subst).unwrap(); - - if let ( - TypeEnum::TFunc(child_method_sig), - TypeEnum::TFunc(parent_method_sig), - ) = ( - unifier.get_ty(*method_ty).as_ref(), - unifier.get_ty(anc_method_ty).as_ref(), - ) { - let ( - FunSignature { args: c_as, ret: c_r, .. }, - FunSignature { args: p_as, ret: p_r, .. }, - ) = (&*child_method_sig.borrow(), &*parent_method_sig.borrow()); - - // arguments - for ( - FuncArg { name: c_name, ty: c_ty, .. }, - FuncArg { name: p_name, ty: p_ty, .. }, - ) in c_as.iter().zip(p_as) - { - if c_name == "self" { - continue; - } - if c_name != p_name - || !Self::check_overload_type_compatible( - unifier, *c_ty, *p_ty, - ) - { - return Err("incompatible parameter".into()); - } - } - - // check the compatibility of c_r and p_r - if !Self::check_overload_type_compatible(unifier, *c_r, *p_r) { - return Err("incompatible parameter".into()); - } - } else { - unreachable!("must be function type") - } - break 'search_for_overload; + // since when this function is called, the ancestors of the direct parent + // are supposed to be already handled, so we only need to deal with the direct parent + let base = class_ancestor_def.get(1).unwrap(); + if let TypeAnnotation::CustomClassKind { id, params: _ } = base { + let base = temp_def_list.get(id.0).unwrap(); + let base = base.read(); + if let TopLevelDef::Class { methods, fields, .. } = &*base { + // handle methods override + // since we need to maintain the order, create a new list + let mut new_child_methods: Vec<(String, Type, DefinitionId)> = Vec::new(); + let mut is_override: HashSet = HashSet::new(); + for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { + // find if there is a method with same name in the child class + let mut to_be_added = + (anc_method_name.to_string(), *anc_method_ty, *anc_method_def_id); + for (class_method_name, class_method_ty, class_method_defid) in + class_methods_def.iter() + { + if class_method_name == anc_method_name { + // ignore and handle self + let ok = class_method_name == "__init__" + && Self::check_overload_function_type( + *class_method_ty, + *anc_method_ty, + unifier, + type_var_to_concrete_def, + ); + if !ok { + return Err("method has same name as ancestors' method, but incompatible type".into()); } + // mark it as added + is_override.insert(class_method_name.to_string()); + to_be_added = ( + class_method_name.to_string(), + *class_method_ty, + *class_method_defid, + ); + break; } } + new_child_methods.push(to_be_added); } + // add those that are not overriding method to the new_child_methods + for (class_method_name, class_method_ty, class_method_defid) in + class_methods_def.iter() + { + if !is_override.contains(class_method_name) { + new_child_methods.push(( + class_method_name.to_string(), + *class_method_ty, + *class_method_defid, + )); + } + } + // use the new_child_methods to replace all the elements in `class_methods_def` + class_methods_def.drain(..); + class_methods_def.extend(new_child_methods); + + // handle class fields + let mut new_child_fields: Vec<(String, Type)> = Vec::new(); + let mut is_override: HashSet = HashSet::new(); + for (anc_field_name, anc_field_ty) in fields { + let mut to_be_added = (anc_field_name.to_string(), *anc_field_ty); + // find if there is a fields with the same name in the child class + for (class_field_name, class_field_ty) in class_fields_def.iter() { + if class_field_name == anc_field_name { + let ok = Self::check_overload_field_type( + *class_field_ty, + *anc_field_ty, + unifier, + type_var_to_concrete_def, + ); + if !ok { + return Err("fields has same name as ancestors' field, but incompatible type".into()); + } + // mark it as added + is_override.insert(class_field_name.to_string()); + to_be_added = (class_field_name.to_string(), *class_field_ty); + break; + } + } + new_child_fields.push(to_be_added); + } + for (class_field_name, class_field_ty) in class_fields_def.iter() { + if !is_override.contains(class_field_name) { + new_child_fields.push((class_field_name.to_string(), *class_field_ty)); + } + } + class_fields_def.drain(..); + class_fields_def.extend(new_child_fields); + } else { + unreachable!("must be top level class def") } + } else { + unreachable!("must be class type annotation") } + Ok(()) } } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index ac8009043..960e982dd 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,3 +1,5 @@ +use crate::typecheck::typedef::TypeVarMeta; + use super::*; #[derive(Clone)] @@ -323,3 +325,56 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec bool { + match (this, other) { + (TypeAnnotation::PrimitiveKind(a), TypeAnnotation::PrimitiveKind(b)) => a == b, + (TypeAnnotation::TypeVarKind(a), TypeAnnotation::TypeVarKind(b)) => { + let a = unifier.get_ty(*a); + let a = a.deref(); + let b = unifier.get_ty(*b); + let b = b.deref(); + if let ( + TypeEnum::TVar { id: a, meta: TypeVarMeta::Generic, .. }, + TypeEnum::TVar { id: b, meta: TypeVarMeta::Generic, .. }, + ) = (a, b) + { + a == b + } else { + unreachable!("must be type var") + } + } + (TypeAnnotation::VirtualKind(a), TypeAnnotation::VirtualKind(b)) + | (TypeAnnotation::ListKind(a), TypeAnnotation::ListKind(b)) => { + check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier) + } + + (TypeAnnotation::TupleKind(a), TypeAnnotation::TupleKind(b)) => { + a.len() == b.len() && { + a.iter() + .zip(b) + .all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier)) + } + } + + ( + TypeAnnotation::CustomClassKind { id: a, params: a_p }, + TypeAnnotation::CustomClassKind { id: b, params: b_p }, + ) => { + a.0 == b.0 && { + a_p.len() == b_p.len() && { + a_p.iter() + .zip(b_p) + .all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier)) + } + } + } + + _ => false, + } +}