diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 50eaccc4..0feaf674 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,4 +1,5 @@ use super::*; +use crate::typecheck::typedef::TypeVarMeta; impl TopLevelComposer { pub fn make_primitives() -> (PrimitiveStore, Unifier) { @@ -116,4 +117,49 @@ impl TopLevelComposer { None } } + + pub fn check_overload_type_compatible(unifier: &mut Unifier, ty: Type, other: Type) -> bool { + let ty = unifier.get_ty(ty); + let ty = ty.as_ref(); + let other = unifier.get_ty(other); + let other = other.as_ref(); + + match (ty, other) { + (TypeEnum::TList { ty }, TypeEnum::TList { ty: other }) + | (TypeEnum::TVirtual { ty }, TypeEnum::TVirtual { ty: other }) => { + Self::check_overload_type_compatible(unifier, *ty, *other) + } + + (TypeEnum::TTuple { ty }, TypeEnum::TTuple { ty: other }) => ty + .iter() + .zip(other) + .all(|(ty, other)| Self::check_overload_type_compatible(unifier, *ty, *other)), + + ( + TypeEnum::TObj { obj_id, params, .. }, + TypeEnum::TObj { obj_id: other_obj_id, params: other_params, .. }, + ) => { + let params = &*params.borrow(); + let other_params = &*other_params.borrow(); + obj_id.0 == other_obj_id.0 + && (params.iter().all(|(var_id, ty)| { + if let Some(other_ty) = other_params.get(var_id) { + Self::check_overload_type_compatible(unifier, *ty, *other_ty) + } else { + false + } + })) + } + + ( + TypeEnum::TVar { id, meta: TypeVarMeta::Generic, .. }, + TypeEnum::TVar { id: other_id, meta: TypeVarMeta::Generic, .. }, + ) => { + // NOTE: directly compare var_id? + *id == *other_id + } + + _ => false, + } + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 84b5cb99..7f375f9f 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -452,7 +452,7 @@ impl TopLevelComposer { } } - // second get all ancestors + // second, get all ancestors let mut ancestors_store: HashMap> = Default::default(); for (class_def, class_ast) in self.definition_ast_list.iter_mut() { let mut class_def = class_def.write(); @@ -513,7 +513,7 @@ impl TopLevelComposer { let mut type_var_to_concrete_def: HashMap = HashMap::new(); for (class_def, class_ast) in def_ast_list { - Self::analyze_single_class( + Self::analyze_single_class_methods_fields( class_def.clone(), &class_ast.as_ref().unwrap().node, &temp_def_list, @@ -531,10 +531,20 @@ impl TopLevelComposer { unifier.unify(ty, target_ty)?; } + // handle the inheritanced methods and fields + for (class_def, _) in def_ast_list { + Self::analyze_single_class_ancestors( + class_def.clone(), + &temp_def_list, + unifier, + primitives, + )?; + } + Ok(()) } - /// step 4, after class methods are done + /// step 4, after class methods are done, top level functions have nothing unknown fn analyze_top_level_function(&mut self) -> Result<(), String> { let def_list = &self.definition_ast_list; let keyword_list = &self.keyword_list; @@ -682,13 +692,7 @@ impl TopLevelComposer { Ok(()) } - /// step 5, field instantiation? - fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { - // TODO: - unimplemented!() - } - - fn analyze_single_class( + fn analyze_single_class_methods_fields( class_def: Arc>, class_ast: &ast::StmtKind<()>, temp_def_list: &[Arc>], @@ -720,7 +724,7 @@ impl TopLevelComposer { { if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { ( - object_id, + *object_id, name.clone(), bases, body, @@ -828,7 +832,7 @@ impl TopLevelComposer { }; type_var_to_concrete_def.insert( dummy_func_arg.ty, - make_self_type_annotation(temp_def_list, *class_id)?, + make_self_type_annotation(temp_def_list, class_id)?, ); result.push(dummy_func_arg); } @@ -874,7 +878,7 @@ impl TopLevelComposer { let dummy_return_type = unifier.get_fresh_var().0; type_var_to_concrete_def.insert( dummy_return_type, - make_self_type_annotation(temp_def_list, *class_id)?, + make_self_type_annotation(temp_def_list, class_id)?, ); dummy_return_type } @@ -943,4 +947,122 @@ impl TopLevelComposer { } Ok(()) } + + fn analyze_single_class_ancestors( + class_def: Arc>, + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + ) -> Result<(), String> { + let mut class_def = class_def.write(); + let ( + _class_id, + class_ancestor_def, + _class_fields_def, + class_methods_def, + _class_type_vars_def, + _class_resolver, + ) = if let TopLevelDef::Class { + object_id, + ancestors, + fields, + methods, + resolver, + type_vars, + .. + } = class_def.deref_mut() + { + (*object_id, ancestors, fields, methods, type_vars, resolver) + } else { + unreachable!("here must be class def ast"); + }; + + for (method_name, method_ty, ..) in class_methods_def { + if method_name == "__init__" { + continue; + } + // search the ancestors from the nearest to the deepest to find overload and check + 'search_for_overload: for anc in class_ancestor_def.iter().skip(1) { + if let TypeAnnotation::CustomClassKind { id, params } = anc { + let anc_class_def = temp_def_list.get(id.0).unwrap().read(); + let anc_class_def = anc_class_def.deref(); + + if let TopLevelDef::Class { methods, type_vars, .. } = anc_class_def { + for (anc_method_name, anc_method_ty, ..) in methods { + // if same name, then is overload, needs check + if anc_method_name == method_name { + let param_ty = params + .iter() + .map(|x| { + get_type_from_type_annotation_kinds( + temp_def_list, + unifier, + primitives, + x, + ) + }) + .collect::, _>>()?; + + let subst = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = + unifier.get_ty(*x).as_ref() + { + *id + } else { + unreachable!() + } + }) + .zip(param_ty.into_iter()) + .collect::>(); + + let anc_method_ty = unifier.subst(*anc_method_ty, &subst).unwrap(); + + if let ( + TypeEnum::TFunc(child_method_sig), + TypeEnum::TFunc(parent_method_sig), + ) = ( + unifier.get_ty(*method_ty).as_ref(), + unifier.get_ty(anc_method_ty).as_ref(), + ) { + let ( + FunSignature { args: c_as, ret: c_r, .. }, + FunSignature { args: p_as, ret: p_r, .. }, + ) = (&*child_method_sig.borrow(), &*parent_method_sig.borrow()); + + // arguments + for ( + FuncArg { name: c_name, ty: c_ty, .. }, + FuncArg { name: p_name, ty: p_ty, .. }, + ) in c_as.iter().zip(p_as) + { + if c_name == "self" { + continue; + } + if c_name != p_name + || !Self::check_overload_type_compatible( + unifier, *c_ty, *p_ty, + ) + { + return Err("incompatible parameter".into()); + } + } + + // check the compatibility of c_r and p_r + if !Self::check_overload_type_compatible(unifier, *c_r, *p_r) { + return Err("incompatible parameter".into()); + } + } else { + unreachable!("must be function type") + } + break 'search_for_overload; + } + } + } + } + } + } + Ok(()) + } } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4efd03b2..52a073be 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -156,7 +156,7 @@ pub fn parse_ast_to_type_annotation_kinds( result } else { return Err("application of type vars to generic class \ - not currently supported" + is not currently supported" .into()); } };