forked from M-Labs/nac3
nac3core: top level cleanup, rewrite ancestors handling, __init__ occruence check
This commit is contained in:
parent
4a9593efa3
commit
7bbd608492
@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
use crate::typecheck::typedef::TypeVarMeta;
|
||||
|
||||
impl TopLevelComposer {
|
||||
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
|
||||
@ -93,14 +92,31 @@ impl TopLevelComposer {
|
||||
pub fn get_all_ancestors_helper(
|
||||
child: &TypeAnnotation,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
) -> Vec<TypeAnnotation> {
|
||||
) -> Result<Vec<TypeAnnotation>, String> {
|
||||
let mut result: Vec<TypeAnnotation> = Vec::new();
|
||||
let mut parent = Self::get_parent(child, temp_def_list);
|
||||
while let Some(p) = parent {
|
||||
parent = Self::get_parent(&p, temp_def_list);
|
||||
result.push(p);
|
||||
let p_id = if let TypeAnnotation::CustomClassKind { id, .. } = &p {
|
||||
*id
|
||||
} else {
|
||||
unreachable!("must be class kind annotation")
|
||||
};
|
||||
// check cycle
|
||||
let no_cycle = result.iter().all(|x| {
|
||||
if let TypeAnnotation::CustomClassKind { id, .. } = x {
|
||||
id.0 != p_id.0
|
||||
} else {
|
||||
unreachable!("must be class kind annotation")
|
||||
}
|
||||
});
|
||||
if no_cycle {
|
||||
result.push(p);
|
||||
} else {
|
||||
return Err("cyclic inheritance detected".into());
|
||||
}
|
||||
}
|
||||
result
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// should only be called when finding all ancestors, so panic when wrong
|
||||
@ -126,51 +142,6 @@ impl TopLevelComposer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_overload_type_compatible(unifier: &mut Unifier, ty: Type, other: Type) -> bool {
|
||||
let ty = unifier.get_ty(ty);
|
||||
let ty = ty.as_ref();
|
||||
let other = unifier.get_ty(other);
|
||||
let other = other.as_ref();
|
||||
|
||||
match (ty, other) {
|
||||
(TypeEnum::TList { ty }, TypeEnum::TList { ty: other })
|
||||
| (TypeEnum::TVirtual { ty }, TypeEnum::TVirtual { ty: other }) => {
|
||||
Self::check_overload_type_compatible(unifier, *ty, *other)
|
||||
}
|
||||
|
||||
(TypeEnum::TTuple { ty }, TypeEnum::TTuple { ty: other }) => ty
|
||||
.iter()
|
||||
.zip(other)
|
||||
.all(|(ty, other)| Self::check_overload_type_compatible(unifier, *ty, *other)),
|
||||
|
||||
(
|
||||
TypeEnum::TObj { obj_id, params, .. },
|
||||
TypeEnum::TObj { obj_id: other_obj_id, params: other_params, .. },
|
||||
) => {
|
||||
let params = &*params.borrow();
|
||||
let other_params = &*other_params.borrow();
|
||||
obj_id.0 == other_obj_id.0
|
||||
&& (params.iter().all(|(var_id, ty)| {
|
||||
if let Some(other_ty) = other_params.get(var_id) {
|
||||
Self::check_overload_type_compatible(unifier, *ty, *other_ty)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
(
|
||||
TypeEnum::TVar { id, meta: TypeVarMeta::Generic, .. },
|
||||
TypeEnum::TVar { id: other_id, meta: TypeVarMeta::Generic, .. },
|
||||
) => {
|
||||
// NOTE: directly compare var_id?
|
||||
*id == *other_id
|
||||
}
|
||||
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// get the var_id of a given TVar type
|
||||
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, String> {
|
||||
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
|
||||
@ -179,4 +150,63 @@ impl TopLevelComposer {
|
||||
Err("not type var".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_overload_function_type(
|
||||
this: Type,
|
||||
other: Type,
|
||||
unifier: &mut Unifier,
|
||||
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
|
||||
) -> bool {
|
||||
let this = unifier.get_ty(this);
|
||||
let this = this.as_ref();
|
||||
let other = unifier.get_ty(other);
|
||||
let other = other.as_ref();
|
||||
if let (TypeEnum::TFunc(this_sig), TypeEnum::TFunc(other_sig)) = (this, other) {
|
||||
let (this_sig, other_sig) = (&*this_sig.borrow(), &*other_sig.borrow());
|
||||
let (
|
||||
FunSignature { args: this_args, ret: this_ret, vars: _this_vars },
|
||||
FunSignature { args: other_args, ret: other_ret, vars: _other_vars },
|
||||
) = (this_sig, other_sig);
|
||||
// check args
|
||||
let args_ok = this_args
|
||||
.iter()
|
||||
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
|
||||
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
|
||||
(name, type_var_to_concrete_def.get(ty).unwrap())
|
||||
}))
|
||||
.all(|(this, other)| {
|
||||
if this.0 == "self" && this.0 == other.0 {
|
||||
true
|
||||
} else {
|
||||
this.0 == other.0
|
||||
&& check_overload_type_annotation_compatible(this.1, other.1, unifier)
|
||||
}
|
||||
});
|
||||
|
||||
// check rets
|
||||
let ret_ok = check_overload_type_annotation_compatible(
|
||||
type_var_to_concrete_def.get(this_ret).unwrap(),
|
||||
type_var_to_concrete_def.get(other_ret).unwrap(),
|
||||
unifier,
|
||||
);
|
||||
|
||||
// return
|
||||
args_ok && ret_ok
|
||||
} else {
|
||||
unreachable!("this function must be called with function type")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_overload_field_type(
|
||||
this: Type,
|
||||
other: Type,
|
||||
unifier: &mut Unifier,
|
||||
type_var_to_concrete_def: &HashMap<Type, TypeAnnotation>,
|
||||
) -> bool {
|
||||
check_overload_type_annotation_compatible(
|
||||
type_var_to_concrete_def.get(&this).unwrap(),
|
||||
type_var_to_concrete_def.get(&other).unwrap(),
|
||||
unifier,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -195,6 +195,7 @@ impl TopLevelComposer {
|
||||
// we do not push anything to the def list, so we keep track of the index
|
||||
// and then push in the correct order after the for loop
|
||||
let mut class_method_index_offset = 0;
|
||||
let mut has_init = false;
|
||||
for b in body {
|
||||
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
|
||||
if self.keyword_list.contains(name) {
|
||||
@ -205,6 +206,9 @@ impl TopLevelComposer {
|
||||
if !defined_class_method_name.insert(global_class_method_name.clone()) {
|
||||
return Err("duplicate class method definition".into());
|
||||
}
|
||||
if method_name == "__init__" {
|
||||
has_init = true;
|
||||
}
|
||||
let method_def_id = self.definition_ast_list.len() + {
|
||||
// plus 1 here since we already have the class def
|
||||
class_method_index_offset += 1;
|
||||
@ -230,6 +234,9 @@ impl TopLevelComposer {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if !has_init {
|
||||
return Err("class def must have __init__ method defined".into());
|
||||
}
|
||||
|
||||
// move the ast to the entry of the class in the ast_list
|
||||
class_def_ast.1 = Some(ast);
|
||||
@ -469,7 +476,7 @@ impl TopLevelComposer {
|
||||
if class_ancestors.is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())
|
||||
Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())?
|
||||
},
|
||||
);
|
||||
}
|
||||
@ -499,9 +506,9 @@ impl TopLevelComposer {
|
||||
/// step 3, class fields and methods
|
||||
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
|
||||
let temp_def_list = self.extract_def_list();
|
||||
let unifier = self.unifier.borrow_mut();
|
||||
let primitives = &self.primitives_ty;
|
||||
let def_ast_list = &self.definition_ast_list;
|
||||
let unifier = self.unifier.borrow_mut();
|
||||
|
||||
let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new();
|
||||
|
||||
@ -517,6 +524,40 @@ impl TopLevelComposer {
|
||||
)?
|
||||
}
|
||||
|
||||
// handle the inheritanced methods and fields
|
||||
let mut current_ancestor_depth: usize = 2;
|
||||
loop {
|
||||
let mut finished = true;
|
||||
|
||||
for (class_def, _) in def_ast_list {
|
||||
let mut class_def = class_def.write();
|
||||
if let TopLevelDef::Class { ancestors, .. } = class_def.deref() {
|
||||
// if the length of the ancestor is equal to the current depth
|
||||
// it means that all the ancestors of the class is handled
|
||||
if ancestors.len() == current_ancestor_depth {
|
||||
finished = false;
|
||||
Self::analyze_single_class_ancestors(
|
||||
class_def.deref_mut(),
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
&mut type_var_to_concrete_def,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if finished {
|
||||
break;
|
||||
} else {
|
||||
current_ancestor_depth += 1;
|
||||
}
|
||||
|
||||
if current_ancestor_depth > def_ast_list.len() + 1 {
|
||||
unreachable!("cannot be longer than the whole top level def list")
|
||||
}
|
||||
}
|
||||
|
||||
// unification of previously assigned typevar
|
||||
for (ty, def) in type_var_to_concrete_def {
|
||||
let target_ty =
|
||||
@ -524,16 +565,6 @@ impl TopLevelComposer {
|
||||
unifier.unify(ty, target_ty)?;
|
||||
}
|
||||
|
||||
// handle the inheritanced methods and fields
|
||||
for (class_def, _) in def_ast_list {
|
||||
Self::analyze_single_class_ancestors(
|
||||
class_def.clone(),
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -596,7 +627,6 @@ impl TopLevelComposer {
|
||||
annotation,
|
||||
)?;
|
||||
|
||||
// if there are same type variables appears, we only need to copy them once
|
||||
let type_vars_within =
|
||||
get_type_var_contained_in_type_annotation(&type_annotation)
|
||||
.into_iter()
|
||||
@ -679,6 +709,7 @@ impl TopLevelComposer {
|
||||
unreachable!("must be both function");
|
||||
}
|
||||
} else {
|
||||
// not top level function def, skip
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -942,16 +973,16 @@ impl TopLevelComposer {
|
||||
}
|
||||
|
||||
fn analyze_single_class_ancestors(
|
||||
class_def: Arc<RwLock<TopLevelDef>>,
|
||||
class_def: &mut TopLevelDef,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
_primitives: &PrimitiveStore,
|
||||
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
|
||||
) -> Result<(), String> {
|
||||
let mut class_def = class_def.write();
|
||||
let (
|
||||
_class_id,
|
||||
class_ancestor_def,
|
||||
_class_fields_def,
|
||||
class_fields_def,
|
||||
class_methods_def,
|
||||
_class_type_vars_def,
|
||||
_class_resolver,
|
||||
@ -963,99 +994,110 @@ impl TopLevelComposer {
|
||||
resolver,
|
||||
type_vars,
|
||||
..
|
||||
} = class_def.deref_mut()
|
||||
} = class_def
|
||||
{
|
||||
(*object_id, ancestors, fields, methods, type_vars, resolver)
|
||||
} else {
|
||||
unreachable!("here must be class def ast");
|
||||
};
|
||||
|
||||
for (method_name, method_ty, ..) in class_methods_def {
|
||||
if method_name == "__init__" {
|
||||
continue;
|
||||
}
|
||||
// search the ancestors from the nearest to the deepest to find overload and check
|
||||
'search_for_overload: for anc in class_ancestor_def.iter().skip(1) {
|
||||
if let TypeAnnotation::CustomClassKind { id, params } = anc {
|
||||
let anc_class_def = temp_def_list.get(id.0).unwrap().read();
|
||||
let anc_class_def = anc_class_def.deref();
|
||||
|
||||
if let TopLevelDef::Class { methods, type_vars, .. } = anc_class_def {
|
||||
for (anc_method_name, anc_method_ty, ..) in methods {
|
||||
// if same name, then is overload, needs check
|
||||
if anc_method_name == method_name {
|
||||
let param_ty = params
|
||||
.iter()
|
||||
.map(|x| {
|
||||
get_type_from_type_annotation_kinds(
|
||||
temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
x,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let subst = type_vars
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if let TypeEnum::TVar { id, .. } =
|
||||
unifier.get_ty(*x).as_ref()
|
||||
{
|
||||
*id
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.zip(param_ty.into_iter())
|
||||
.collect::<HashMap<u32, Type>>();
|
||||
|
||||
let anc_method_ty = unifier.subst(*anc_method_ty, &subst).unwrap();
|
||||
|
||||
if let (
|
||||
TypeEnum::TFunc(child_method_sig),
|
||||
TypeEnum::TFunc(parent_method_sig),
|
||||
) = (
|
||||
unifier.get_ty(*method_ty).as_ref(),
|
||||
unifier.get_ty(anc_method_ty).as_ref(),
|
||||
) {
|
||||
let (
|
||||
FunSignature { args: c_as, ret: c_r, .. },
|
||||
FunSignature { args: p_as, ret: p_r, .. },
|
||||
) = (&*child_method_sig.borrow(), &*parent_method_sig.borrow());
|
||||
|
||||
// arguments
|
||||
for (
|
||||
FuncArg { name: c_name, ty: c_ty, .. },
|
||||
FuncArg { name: p_name, ty: p_ty, .. },
|
||||
) in c_as.iter().zip(p_as)
|
||||
{
|
||||
if c_name == "self" {
|
||||
continue;
|
||||
}
|
||||
if c_name != p_name
|
||||
|| !Self::check_overload_type_compatible(
|
||||
unifier, *c_ty, *p_ty,
|
||||
)
|
||||
{
|
||||
return Err("incompatible parameter".into());
|
||||
}
|
||||
}
|
||||
|
||||
// check the compatibility of c_r and p_r
|
||||
if !Self::check_overload_type_compatible(unifier, *c_r, *p_r) {
|
||||
return Err("incompatible parameter".into());
|
||||
}
|
||||
} else {
|
||||
unreachable!("must be function type")
|
||||
}
|
||||
break 'search_for_overload;
|
||||
// since when this function is called, the ancestors of the direct parent
|
||||
// are supposed to be already handled, so we only need to deal with the direct parent
|
||||
let base = class_ancestor_def.get(1).unwrap();
|
||||
if let TypeAnnotation::CustomClassKind { id, params: _ } = base {
|
||||
let base = temp_def_list.get(id.0).unwrap();
|
||||
let base = base.read();
|
||||
if let TopLevelDef::Class { methods, fields, .. } = &*base {
|
||||
// handle methods override
|
||||
// since we need to maintain the order, create a new list
|
||||
let mut new_child_methods: Vec<(String, Type, DefinitionId)> = Vec::new();
|
||||
let mut is_override: HashSet<String> = HashSet::new();
|
||||
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
|
||||
// find if there is a method with same name in the child class
|
||||
let mut to_be_added =
|
||||
(anc_method_name.to_string(), *anc_method_ty, *anc_method_def_id);
|
||||
for (class_method_name, class_method_ty, class_method_defid) in
|
||||
class_methods_def.iter()
|
||||
{
|
||||
if class_method_name == anc_method_name {
|
||||
// ignore and handle self
|
||||
let ok = class_method_name == "__init__"
|
||||
&& Self::check_overload_function_type(
|
||||
*class_method_ty,
|
||||
*anc_method_ty,
|
||||
unifier,
|
||||
type_var_to_concrete_def,
|
||||
);
|
||||
if !ok {
|
||||
return Err("method has same name as ancestors' method, but incompatible type".into());
|
||||
}
|
||||
// mark it as added
|
||||
is_override.insert(class_method_name.to_string());
|
||||
to_be_added = (
|
||||
class_method_name.to_string(),
|
||||
*class_method_ty,
|
||||
*class_method_defid,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
new_child_methods.push(to_be_added);
|
||||
}
|
||||
// add those that are not overriding method to the new_child_methods
|
||||
for (class_method_name, class_method_ty, class_method_defid) in
|
||||
class_methods_def.iter()
|
||||
{
|
||||
if !is_override.contains(class_method_name) {
|
||||
new_child_methods.push((
|
||||
class_method_name.to_string(),
|
||||
*class_method_ty,
|
||||
*class_method_defid,
|
||||
));
|
||||
}
|
||||
}
|
||||
// use the new_child_methods to replace all the elements in `class_methods_def`
|
||||
class_methods_def.drain(..);
|
||||
class_methods_def.extend(new_child_methods);
|
||||
|
||||
// handle class fields
|
||||
let mut new_child_fields: Vec<(String, Type)> = Vec::new();
|
||||
let mut is_override: HashSet<String> = HashSet::new();
|
||||
for (anc_field_name, anc_field_ty) in fields {
|
||||
let mut to_be_added = (anc_field_name.to_string(), *anc_field_ty);
|
||||
// find if there is a fields with the same name in the child class
|
||||
for (class_field_name, class_field_ty) in class_fields_def.iter() {
|
||||
if class_field_name == anc_field_name {
|
||||
let ok = Self::check_overload_field_type(
|
||||
*class_field_ty,
|
||||
*anc_field_ty,
|
||||
unifier,
|
||||
type_var_to_concrete_def,
|
||||
);
|
||||
if !ok {
|
||||
return Err("fields has same name as ancestors' field, but incompatible type".into());
|
||||
}
|
||||
// mark it as added
|
||||
is_override.insert(class_field_name.to_string());
|
||||
to_be_added = (class_field_name.to_string(), *class_field_ty);
|
||||
break;
|
||||
}
|
||||
}
|
||||
new_child_fields.push(to_be_added);
|
||||
}
|
||||
for (class_field_name, class_field_ty) in class_fields_def.iter() {
|
||||
if !is_override.contains(class_field_name) {
|
||||
new_child_fields.push((class_field_name.to_string(), *class_field_ty));
|
||||
}
|
||||
}
|
||||
class_fields_def.drain(..);
|
||||
class_fields_def.extend(new_child_fields);
|
||||
} else {
|
||||
unreachable!("must be top level class def")
|
||||
}
|
||||
} else {
|
||||
unreachable!("must be class type annotation")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
use crate::typecheck::typedef::TypeVarMeta;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -323,3 +325,56 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// check the type compatibility for overload
|
||||
pub fn check_overload_type_annotation_compatible(
|
||||
this: &TypeAnnotation,
|
||||
other: &TypeAnnotation,
|
||||
unifier: &mut Unifier,
|
||||
) -> bool {
|
||||
match (this, other) {
|
||||
(TypeAnnotation::PrimitiveKind(a), TypeAnnotation::PrimitiveKind(b)) => a == b,
|
||||
(TypeAnnotation::TypeVarKind(a), TypeAnnotation::TypeVarKind(b)) => {
|
||||
let a = unifier.get_ty(*a);
|
||||
let a = a.deref();
|
||||
let b = unifier.get_ty(*b);
|
||||
let b = b.deref();
|
||||
if let (
|
||||
TypeEnum::TVar { id: a, meta: TypeVarMeta::Generic, .. },
|
||||
TypeEnum::TVar { id: b, meta: TypeVarMeta::Generic, .. },
|
||||
) = (a, b)
|
||||
{
|
||||
a == b
|
||||
} else {
|
||||
unreachable!("must be type var")
|
||||
}
|
||||
}
|
||||
(TypeAnnotation::VirtualKind(a), TypeAnnotation::VirtualKind(b))
|
||||
| (TypeAnnotation::ListKind(a), TypeAnnotation::ListKind(b)) => {
|
||||
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
|
||||
}
|
||||
|
||||
(TypeAnnotation::TupleKind(a), TypeAnnotation::TupleKind(b)) => {
|
||||
a.len() == b.len() && {
|
||||
a.iter()
|
||||
.zip(b)
|
||||
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
TypeAnnotation::CustomClassKind { id: a, params: a_p },
|
||||
TypeAnnotation::CustomClassKind { id: b, params: b_p },
|
||||
) => {
|
||||
a.0 == b.0 && {
|
||||
a_p.len() == b_p.len() && {
|
||||
a_p.iter()
|
||||
.zip(b_p)
|
||||
.all(|(a, b)| check_overload_type_annotation_compatible(a, b, unifier))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user