forked from M-Labs/nac3
1
0
Fork 0

nac3core: top level cleanup, rewrite ancestors handling, __init__ occruence check

This commit is contained in:
ychenfo 2021-08-31 09:57:07 +08:00
parent 4a9593efa3
commit 7bbd608492
3 changed files with 274 additions and 147 deletions

View File

@ -1,5 +1,4 @@
use super::*; use super::*;
use crate::typecheck::typedef::TypeVarMeta;
impl TopLevelComposer { impl TopLevelComposer {
pub fn make_primitives() -> (PrimitiveStore, Unifier) { pub fn make_primitives() -> (PrimitiveStore, Unifier) {
@ -93,14 +92,31 @@ impl TopLevelComposer {
pub fn get_all_ancestors_helper( pub fn get_all_ancestors_helper(
child: &TypeAnnotation, child: &TypeAnnotation,
temp_def_list: &[Arc<RwLock<TopLevelDef>>], temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) -> Vec<TypeAnnotation> { ) -> Result<Vec<TypeAnnotation>, String> {
let mut result: Vec<TypeAnnotation> = Vec::new(); let mut result: Vec<TypeAnnotation> = Vec::new();
let mut parent = Self::get_parent(child, temp_def_list); let mut parent = Self::get_parent(child, temp_def_list);
while let Some(p) = parent { while let Some(p) = parent {
parent = Self::get_parent(&p, temp_def_list); parent = Self::get_parent(&p, temp_def_list);
result.push(p); let p_id = if let TypeAnnotation::CustomClassKind { id, .. } = &p {
*id
} else {
unreachable!("must be class kind annotation")
};
// check cycle
let no_cycle = result.iter().all(|x| {
if let TypeAnnotation::CustomClassKind { id, .. } = x {
id.0 != p_id.0
} else {
unreachable!("must be class kind annotation")
}
});
if no_cycle {
result.push(p);
} else {
return Err("cyclic inheritance detected".into());
}
} }
result Ok(result)
} }
/// should only be called when finding all ancestors, so panic when wrong /// should only be called when finding all ancestors, so panic when wrong
@ -126,51 +142,6 @@ impl TopLevelComposer {
} }
} }
pub fn check_overload_type_compatible(unifier: &mut Unifier, ty: Type, other: Type) -> bool {
let ty = unifier.get_ty(ty);
let ty = ty.as_ref();
let other = unifier.get_ty(other);
let other = other.as_ref();
match (ty, other) {
(TypeEnum::TList { ty }, TypeEnum::TList { ty: other })
| (TypeEnum::TVirtual { ty }, TypeEnum::TVirtual { ty: other }) => {
Self::check_overload_type_compatible(unifier, *ty, *other)
}
(TypeEnum::TTuple { ty }, TypeEnum::TTuple { ty: other }) => ty
.iter()
.zip(other)
.all(|(ty, other)| Self::check_overload_type_compatible(unifier, *ty, *other)),
(
TypeEnum::TObj { obj_id, params, .. },
TypeEnum::TObj { obj_id: other_obj_id, params: other_params, .. },
) => {
let params = &*params.borrow();
let other_params = &*other_params.borrow();
obj_id.0 == other_obj_id.0
&& (params.iter().all(|(var_id, ty)| {
if let Some(other_ty) = other_params.get(var_id) {
Self::check_overload_type_compatible(unifier, *ty, *other_ty)
} else {
false
}
}))
}
(
TypeEnum::TVar { id, meta: TypeVarMeta::Generic, .. },
TypeEnum::TVar { id: other_id, meta: TypeVarMeta::Generic, .. },
) => {
// NOTE: directly compare var_id?
*id == *other_id
}
_ => false,
}
}
/// get the var_id of a given TVar type /// get the var_id of a given TVar type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> { pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
@ -179,4 +150,63 @@ impl TopLevelComposer {
Err("not type var".to_string()) Err("not type var".to_string())
} }
} }
pub fn check_overload_function_type(
this: Type,
other: Type,
unifier: &mut Unifier,
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
) -> bool {
let this = unifier.get_ty(this);
let this = this.as_ref();
let other = unifier.get_ty(other);
let other = other.as_ref();
if let (TypeEnum::TFunc(this_sig), TypeEnum::TFunc(other_sig)) = (this, other) {
let (this_sig, other_sig) = (&*this_sig.borrow(), &*other_sig.borrow());
let (
FunSignature { args: this_args, ret: this_ret, vars: _this_vars },
FunSignature { args: other_args, ret: other_ret, vars: _other_vars },
) = (this_sig, other_sig);
// check args
let args_ok = this_args
.iter()
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
(name, type_var_to_concrete_def.get(ty).unwrap())
}))
.all(|(this, other)| {
if this.0 == "self" && this.0 == other.0 {
true
} else {
this.0 == other.0
&& check_overload_type_annotation_compatible(this.1, other.1, unifier)
}
});
// check rets
let ret_ok = check_overload_type_annotation_compatible(
type_var_to_concrete_def.get(this_ret).unwrap(),
type_var_to_concrete_def.get(other_ret).unwrap(),
unifier,
);
// return
args_ok && ret_ok
} else {
unreachable!("this function must be called with function type")
}
}
pub fn check_overload_field_type(
this: Type,
other: Type,
unifier: &mut Unifier,
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
) -> bool {
check_overload_type_annotation_compatible(
type_var_to_concrete_def.get(&this).unwrap(),
type_var_to_concrete_def.get(&other).unwrap(),
unifier,
)
}
} }

View File

@ -195,6 +195,7 @@ impl TopLevelComposer {
// we do not push anything to the def list, so we keep track of the index // we do not push anything to the def list, so we keep track of the index
// and then push in the correct order after the for loop // and then push in the correct order after the for loop
let mut class_method_index_offset = 0; let mut class_method_index_offset = 0;
let mut has_init = false;
for b in body { for b in body {
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
if self.keyword_list.contains(name) { if self.keyword_list.contains(name) {
@ -205,6 +206,9 @@ impl TopLevelComposer {
if !defined_class_method_name.insert(global_class_method_name.clone()) { if !defined_class_method_name.insert(global_class_method_name.clone()) {
return Err("duplicate class method definition".into()); return Err("duplicate class method definition".into());
} }
if method_name == "__init__" {
has_init = true;
}
let method_def_id = self.definition_ast_list.len() + { let method_def_id = self.definition_ast_list.len() + {
// plus 1 here since we already have the class def // plus 1 here since we already have the class def
class_method_index_offset += 1; class_method_index_offset += 1;
@ -230,6 +234,9 @@ impl TopLevelComposer {
continue; continue;
} }
} }
if !has_init {
return Err("class def must have __init__ method defined".into());
}
// move the ast to the entry of the class in the ast_list // move the ast to the entry of the class in the ast_list
class_def_ast.1 = Some(ast); class_def_ast.1 = Some(ast);
@ -469,7 +476,7 @@ impl TopLevelComposer {
if class_ancestors.is_empty() { if class_ancestors.is_empty() {
vec![] vec![]
} else { } else {
Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice()) Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())?
}, },
); );
} }
@ -499,9 +506,9 @@ impl TopLevelComposer {
/// step 3, class fields and methods /// step 3, class fields and methods
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
let temp_def_list = self.extract_def_list(); let temp_def_list = self.extract_def_list();
let unifier = self.unifier.borrow_mut();
let primitives = &self.primitives_ty; let primitives = &self.primitives_ty;
let def_ast_list = &self.definition_ast_list; let def_ast_list = &self.definition_ast_list;
let unifier = self.unifier.borrow_mut();
let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new(); let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new();
@ -517,6 +524,40 @@ impl TopLevelComposer {
)? )?
} }
// handle the inheritanced methods and fields
let mut current_ancestor_depth: usize = 2;
loop {
let mut finished = true;
for (class_def, _) in def_ast_list {
let mut class_def = class_def.write();
if let TopLevelDef::Class { ancestors, .. } = class_def.deref() {
// if the length of the ancestor is equal to the current depth
// it means that all the ancestors of the class is handled
if ancestors.len() == current_ancestor_depth {
finished = false;
Self::analyze_single_class_ancestors(
class_def.deref_mut(),
&temp_def_list,
unifier,
primitives,
&mut type_var_to_concrete_def,
)?;
}
}
}
if finished {
break;
} else {
current_ancestor_depth += 1;
}
if current_ancestor_depth > def_ast_list.len() + 1 {
unreachable!("cannot be longer than the whole top level def list")
}
}
// unification of previously assigned typevar // unification of previously assigned typevar
for (ty, def) in type_var_to_concrete_def { for (ty, def) in type_var_to_concrete_def {
let target_ty = let target_ty =
@ -524,16 +565,6 @@ impl TopLevelComposer {
unifier.unify(ty, target_ty)?; unifier.unify(ty, target_ty)?;
} }
// handle the inheritanced methods and fields
for (class_def, _) in def_ast_list {
Self::analyze_single_class_ancestors(
class_def.clone(),
&temp_def_list,
unifier,
primitives,
)?;
}
Ok(()) Ok(())
} }
@ -596,7 +627,6 @@ impl TopLevelComposer {
annotation, annotation,
)?; )?;
// if there are same type variables appears, we only need to copy them once
let type_vars_within = let type_vars_within =
get_type_var_contained_in_type_annotation(&type_annotation) get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter() .into_iter()
@ -679,6 +709,7 @@ impl TopLevelComposer {
unreachable!("must be both function"); unreachable!("must be both function");
} }
} else { } else {
// not top level function def, skip
continue; continue;
} }
} }
@ -942,16 +973,16 @@ impl TopLevelComposer {
} }
fn analyze_single_class_ancestors( fn analyze_single_class_ancestors(
class_def: Arc<RwLock<TopLevelDef>>, class_def: &mut TopLevelDef,
temp_def_list: &[Arc<RwLock<TopLevelDef>>], temp_def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, _primitives: &PrimitiveStore,
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), String> { ) -> Result<(), String> {
let mut class_def = class_def.write();
let ( let (
_class_id, _class_id,
class_ancestor_def, class_ancestor_def,
_class_fields_def, class_fields_def,
class_methods_def, class_methods_def,
_class_type_vars_def, _class_type_vars_def,
_class_resolver, _class_resolver,
@ -963,99 +994,110 @@ impl TopLevelComposer {
resolver, resolver,
type_vars, type_vars,
.. ..
} = class_def.deref_mut() } = class_def
{ {
(*object_id, ancestors, fields, methods, type_vars, resolver) (*object_id, ancestors, fields, methods, type_vars, resolver)
} else { } else {
unreachable!("here must be class def ast"); unreachable!("here must be class def ast");
}; };
for (method_name, method_ty, ..) in class_methods_def { // since when this function is called, the ancestors of the direct parent
if method_name == "__init__" { // are supposed to be already handled, so we only need to deal with the direct parent
continue; let base = class_ancestor_def.get(1).unwrap();
} if let TypeAnnotation::CustomClassKind { id, params: _ } = base {
// search the ancestors from the nearest to the deepest to find overload and check let base = temp_def_list.get(id.0).unwrap();
'search_for_overload: for anc in class_ancestor_def.iter().skip(1) { let base = base.read();
if let TypeAnnotation::CustomClassKind { id, params } = anc { if let TopLevelDef::Class { methods, fields, .. } = &*base {
let anc_class_def = temp_def_list.get(id.0).unwrap().read(); // handle methods override
let anc_class_def = anc_class_def.deref(); // since we need to maintain the order, create a new list
let mut new_child_methods: Vec<(String, Type, DefinitionId)> = Vec::new();
if let TopLevelDef::Class { methods, type_vars, .. } = anc_class_def { let mut is_override: HashSet<String> = HashSet::new();
for (anc_method_name, anc_method_ty, ..) in methods { for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
// if same name, then is overload, needs check // find if there is a method with same name in the child class
if anc_method_name == method_name { let mut to_be_added =
let param_ty = params (anc_method_name.to_string(), *anc_method_ty, *anc_method_def_id);
.iter() for (class_method_name, class_method_ty, class_method_defid) in
.map(|x| { class_methods_def.iter()
get_type_from_type_annotation_kinds( {
temp_def_list, if class_method_name == anc_method_name {
unifier, // ignore and handle self
primitives, let ok = class_method_name == "__init__"
x, && Self::check_overload_function_type(
) *class_method_ty,
}) *anc_method_ty,
.collect::<Result<Vec<_>, _>>()?; unifier,
type_var_to_concrete_def,
let subst = type_vars );
.iter() if !ok {
.map(|x| { return Err("method has same name as ancestors' method, but incompatible type".into());
if let TypeEnum::TVar { id, .. } =
unifier.get_ty(*x).as_ref()
{
*id
} else {
unreachable!()
}
})
.zip(param_ty.into_iter())
.collect::<HashMap<u32, Type>>();
let anc_method_ty = unifier.subst(*anc_method_ty, &subst).unwrap();
if let (
TypeEnum::TFunc(child_method_sig),
TypeEnum::TFunc(parent_method_sig),
) = (
unifier.get_ty(*method_ty).as_ref(),
unifier.get_ty(anc_method_ty).as_ref(),
) {
let (
FunSignature { args: c_as, ret: c_r, .. },
FunSignature { args: p_as, ret: p_r, .. },
) = (&*child_method_sig.borrow(), &*parent_method_sig.borrow());
// arguments
for (
FuncArg { name: c_name, ty: c_ty, .. },
FuncArg { name: p_name, ty: p_ty, .. },
) in c_as.iter().zip(p_as)
{
if c_name == "self" {
continue;
}
if c_name != p_name
|| !Self::check_overload_type_compatible(
unifier, *c_ty, *p_ty,
)
{
return Err("incompatible parameter".into());
}
}
// check the compatibility of c_r and p_r
if !Self::check_overload_type_compatible(unifier, *c_r, *p_r) {
return Err("incompatible parameter".into());
}
} else {
unreachable!("must be function type")
}
break 'search_for_overload;
} }
// mark it as added
is_override.insert(class_method_name.to_string());
to_be_added = (
class_method_name.to_string(),
*class_method_ty,
*class_method_defid,
);
break;
} }
} }
new_child_methods.push(to_be_added);
} }
// add those that are not overriding method to the new_child_methods
for (class_method_name, class_method_ty, class_method_defid) in
class_methods_def.iter()
{
if !is_override.contains(class_method_name) {
new_child_methods.push((
class_method_name.to_string(),
*class_method_ty,
*class_method_defid,
));
}
}
// use the new_child_methods to replace all the elements in `class_methods_def`
class_methods_def.drain(..);
class_methods_def.extend(new_child_methods);
// handle class fields
let mut new_child_fields: Vec<(String, Type)> = Vec::new();
let mut is_override: HashSet<String> = HashSet::new();
for (anc_field_name, anc_field_ty) in fields {
let mut to_be_added = (anc_field_name.to_string(), *anc_field_ty);
// find if there is a fields with the same name in the child class
for (class_field_name, class_field_ty) in class_fields_def.iter() {
if class_field_name == anc_field_name {
let ok = Self::check_overload_field_type(
*class_field_ty,
*anc_field_ty,
unifier,
type_var_to_concrete_def,
);
if !ok {
return Err("fields has same name as ancestors' field, but incompatible type".into());
}
// mark it as added
is_override.insert(class_field_name.to_string());
to_be_added = (class_field_name.to_string(), *class_field_ty);
break;
}
}
new_child_fields.push(to_be_added);
}
for (class_field_name, class_field_ty) in class_fields_def.iter() {
if !is_override.contains(class_field_name) {
new_child_fields.push((class_field_name.to_string(), *class_field_ty));
}
}
class_fields_def.drain(..);
class_fields_def.extend(new_child_fields);
} else {
unreachable!("must be top level class def")
} }
} else {
unreachable!("must be class type annotation")
} }
Ok(()) Ok(())
} }
} }

View File

@ -1,3 +1,5 @@
use crate::typecheck::typedef::TypeVarMeta;
use super::*; use super::*;
#[derive(Clone)] #[derive(Clone)]
@ -323,3 +325,56 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
} }
result result
} }
/// check the type compatibility for overload
pub fn check_overload_type_annotation_compatible(
this: &TypeAnnotation,
other: &TypeAnnotation,
unifier: &mut Unifier,
) -> bool {
match (this, other) {
(TypeAnnotation::PrimitiveKind(a), TypeAnnotation::PrimitiveKind(b)) => a == b,
(TypeAnnotation::TypeVarKind(a), TypeAnnotation::TypeVarKind(b)) => {
let a = unifier.get_ty(*a);
let a = a.deref();
let b = unifier.get_ty(*b);
let b = b.deref();
if let (
TypeEnum::TVar { id: a, meta: TypeVarMeta::Generic, .. },
TypeEnum::TVar { id: b, meta: TypeVarMeta::Generic, .. },
) = (a, b)
{
a == b
} else {
unreachable!("must be type var")
}
}
(TypeAnnotation::VirtualKind(a), TypeAnnotation::VirtualKind(b))
| (TypeAnnotation::ListKind(a), TypeAnnotation::ListKind(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
}
(TypeAnnotation::TupleKind(a), TypeAnnotation::TupleKind(b)) => {
a.len() == b.len() && {
a.iter()
.zip(b)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
(
TypeAnnotation::CustomClassKind { id: a, params: a_p },
TypeAnnotation::CustomClassKind { id: b, params: b_p },
) => {
a.0 == b.0 && {
a_p.len() == b_p.len() && {
a_p.iter()
.zip(b_p)
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
}
}
}
_ => false,
}
}