hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
1 changed files with 201 additions and 166 deletions
Showing only changes of commit 529442590f - Show all commits

View File

@ -4,9 +4,9 @@ use std::{collections::HashMap, collections::HashSet, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
use crate::symbol_resolver::SymbolResolver; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Mapping};
use crate::typecheck::typedef::{FunSignature, FuncArg}; use crate::typecheck::typedef::{FunSignature, FuncArg};
use itertools::{Itertools, chain}; use itertools::Itertools;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use rustpython_parser::ast::{self, Stmt}; use rustpython_parser::ast::{self, Stmt};
@ -154,7 +154,7 @@ impl TopLevelComposer {
top_level_def_list.into_iter().zip(ast_list).collect_vec() top_level_def_list.into_iter().zip(ast_list).collect_vec()
).into(), ).into(),
primitives: primitives.0, primitives: primitives.0,
unifier: primitives.1.into(), unifier: primitives.1,
class_method_to_def_id: Default::default(), class_method_to_def_id: Default::default(),
to_be_analyzed_class: Default::default(), to_be_analyzed_class: Default::default(),
}; };
@ -252,22 +252,11 @@ impl TopLevelComposer {
// move the ast to the entry of the class in the ast_list // move the ast to the entry of the class in the ast_list
class_def_ast.1 = Some(ast); class_def_ast.1 = Some(ast);
// put methods into the class def
{
let mut class_def = class_def_ast.0.write();
let class_def_methods =
if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() {
methods
} else { unimplemented!() };
for (name, _, id) in &class_method_name_def_ids {
class_def_methods.push((name.into(), self.primitives.none, *id));
}
}
// now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order // now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order
def_list.push(class_def_ast); def_list.push(class_def_ast);
for (_, def, _) in class_method_name_def_ids { for (name, def, id) in class_method_name_def_ids {
def_list.push((def, None)); def_list.push((def, None));
self.class_method_to_def_id.insert(name, id);
} }
// put the constructor into the def_list // put the constructor into the def_list
@ -280,8 +269,7 @@ impl TopLevelComposer {
)); ));
// class, put its def_id into the to be analyzed set // class, put its def_id into the to be analyzed set
let to_be_analyzed = &mut self.to_be_analyzed_class; self.to_be_analyzed_class.push(DefinitionId(class_def_id));
to_be_analyzed.push(DefinitionId(class_def_id));
Ok((class_name, DefinitionId(class_def_id))) Ok((class_name, DefinitionId(class_def_id)))
} }
@ -461,38 +449,50 @@ impl TopLevelComposer {
return Err("expect concrete class/type to be base class".into()); return Err("expect concrete class/type to be base class".into());
}; };
// write to the class ancestors // write to the class ancestors, make sure the uniqueness
class_ancestors.push(base_id); if !class_ancestors.contains(&base_id) {
class_ancestors.push(base_id);
} else {
return Err("cannot specify the same base class twice".into())
}
} }
} }
Ok(()) Ok(())
} }
/// step 3, class fields and methods /// step 3, class fields and methods
// FIXME: need analyze base classes here
// FIXME: how to deal with self type
// FIXME: how to prevent cycles
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> {
let mut def_list = self.definition_ast_list.write(); let mut def_ast_list = self.definition_ast_list.write();
let converted_top_level = &self.to_top_level_context(); let converted_top_level = &self.to_top_level_context();
let primitives = &self.primitives; let primitives = &self.primitives;
let to_be_analyzed_class = &mut self.to_be_analyzed_class; let to_be_analyzed_class = &mut self.to_be_analyzed_class;
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
while !to_be_analyzed_class.is_empty() { 'class: loop{
if to_be_analyzed_class.is_empty() { break; }
let class_ind = to_be_analyzed_class.remove(0).0; let class_ind = to_be_analyzed_class.remove(0).0;
let (class_name, class_body, classs_def) = { let (class_name, class_body, class_resolver) = {
let class_ast = def_list[class_ind].1.as_ref(); let (class_def, class_ast) = &mut def_ast_list[class_ind];
if let Some(ast::Located { if let Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, .. }, .. node: ast::StmtKind::ClassDef { name, body, .. }, ..
}) = class_ast }) = class_ast.as_ref()
{ {
let class_def = def_list[class_ind].0; if let TopLevelDef::Class { resolver, .. } = class_def.write().deref() {
(name, body, class_def) (name, body, resolver.as_ref().unwrap().clone())
} else { unreachable!() }
} else { } else {
unreachable!("should be class def ast") unreachable!("should be class def ast")
} }
}; };
let class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; // need these vectors to check re-defining methods, class fields
let class_fields_parsing_result: Vec<(String, Type)> = vec![]; // and store the parsed result in case some method cannot be typed for now
let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![];
let mut class_fields_parsing_result: Vec<(String, Type)> = vec![];
for b in class_body { for b in class_body {
if let ast::StmtKind::FunctionDef { if let ast::StmtKind::FunctionDef {
args: method_args_ast, args: method_args_ast,
@ -502,179 +502,191 @@ impl TopLevelComposer {
.. ..
} = &b.node } = &b.node
{ {
let (class_def, method_def) = { let arg_name_tys: Vec<(String, Type)> = {
// unwrap should not fail let mut result = vec![];
let method_ind = class_method_to_def_id for a in &method_args_ast.args {
.get(&Self::name_mangling(class_name.into(), method_name)) if a.node.arg != "self" {
.unwrap() let annotation = a
.0;
// split the def_list to two parts to get the
// mutable reference to both the method and the class
assert_ne!(method_ind, class_ind);
let min_ind =
(if method_ind > class_ind { class_ind } else { method_ind }) + 1;
let (head_slice, tail_slice) = def_list.split_at_mut(min_ind);
let (new_method_ind, new_class_ind) = (
if method_ind >= min_ind { method_ind - min_ind } else { method_ind },
if class_ind >= min_ind { class_ind - min_ind } else { class_ind },
);
if new_class_ind == class_ind {
(&mut head_slice[new_class_ind], &mut tail_slice[new_method_ind])
} else {
(&mut tail_slice[new_class_ind], &mut head_slice[new_method_ind])
}
};
let (class_fields, class_methods, class_resolver) = {
if let TopLevelDef::Class { resolver, fields, methods, .. } =
class_def.0.get_mut()
{
(fields, methods, resolver)
} else {
unreachable!("must be class def here")
}
};
let arg_tys = method_args_ast
.args
.iter()
.map(|x| -> Result<Type, String> {
if x.node.arg != "self" {
let annotation = x
.node .node
.annotation .annotation
.as_ref() .as_ref()
.ok_or_else(|| { .ok_or_else(|| {
"type annotation for function parameter is needed".to_string() "type annotation for function parameter is needed".to_string()
})? })?.as_ref();
.as_ref();
let ty = let ty =
class_resolver.as_ref().unwrap().lock().parse_type_annotation( class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level, converted_top_level,
unifier.borrow_mut(), unifier.borrow_mut(),
primitives, primitives,
annotation, annotation,
)?; )?;
Ok(ty) if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
}
result.push((a.node.arg.to_string(), ty));
} else { } else {
// TODO: handle self, how // TODO: handle self, how
unimplemented!() unimplemented!()
} }
}) }
.collect::<Result<Vec<_>, _>>()?; result
};
let ret_ty = if method_name != "__init__" { let method_type_var =
method_returns_ast arg_name_tys
.as_ref() .iter()
.map(|x| .filter_map(|(_, ty)| {
class_resolver.as_ref().unwrap().lock().parse_type_annotation( let ty_enum = unifier.get_ty(*ty);
converted_top_level, if let TypeEnum::TVar { id, .. } = ty_enum.as_ref() {
unifier.borrow_mut(), Some((*id, *ty))
primitives, } else { None }
x.as_ref(), })
.collect::<Mapping<u32>>();
let ret_ty = {
if method_name != "__init__" {
let ty = method_returns_ast
.as_ref()
.map(|x|
class_resolver.as_ref().lock().parse_type_annotation(
converted_top_level,
unifier.borrow_mut(),
primitives,
x.as_ref(),
)
) )
) .ok_or_else(|| "return type annotation error".to_string())??;
.ok_or_else(|| "return type annotation needed".to_string())?? if !Self::check_ty_analyzed(ty, unifier, to_be_analyzed_class) {
} else { to_be_analyzed_class.push(DefinitionId(class_ind));
// TODO: self type, how continue 'class;
unimplemented!() } else { ty }
} else {
// TODO: __init__ function, self type, how
unimplemented!()
}
}; };
// handle fields // handle fields
if method_name == "__init__" { let class_field_name_tys: Option<Vec<(String, Type)>> =
for body in method_body_ast { if method_name == "__init__" {
match &body.node { let mut result: Vec<(String, Type)> = vec![];
ast::StmtKind::AnnAssign { for body in method_body_ast {
target, match &body.node {
annotation, ast::StmtKind::AnnAssign {
.. target,
} if { annotation,
if let ast::ExprKind::Attribute {
value,
attr,
.. ..
} = &target.node { } if {
if let ast::ExprKind::Name {id, ..} = &value.node { if let ast::ExprKind::Attribute {
id == "self" value, ..
} = &target.node {
matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == "self")
} else { false } } else { false }
} else { false } } => {
} => { let field_ty = class_resolver.as_ref().lock().parse_type_annotation(
// TODO: record this field with its type converted_top_level,
}, unifier.borrow_mut(),
primitives,
annotation.as_ref())?;
if !Self::check_ty_analyzed(field_ty, unifier, to_be_analyzed_class) {
to_be_analyzed_class.push(DefinitionId(class_ind));
continue 'class;
} else {
result.push((
if let ast::ExprKind::Attribute {
attr, ..
} = &target.node {
attr.to_string()
} else { unreachable!() },
field_ty
)) }
},
// TODO: exclude those without type annotation // exclude those without type annotation
ast::StmtKind::Assign { ast::StmtKind::Assign {
targets, targets, ..
.. } if {
} if { if let ast::ExprKind::Attribute {
if let ast::ExprKind::Attribute { value, ..
value, } = &targets[0].node {
attr, matches!(
.. &value.node,
} = &targets[0].node { ast::ExprKind::Name {id, ..} if id == "self")
if let ast::ExprKind::Name {id, ..} = &value.node {
id == "self"
} else { false } } else { false }
} else { false } } => {
} => { return Err("class fields type annotation needed".into())
unimplemented!() },
},
// do nothing // do nothing
_ => { } _ => { }
} }
} };
Some(result)
} else { None };
// current method all type ok, put the current method into the list
if class_methods_parsing_result
.iter()
.any(|(name, _, _)| name == method_name) {
return Err("duplicate method definition".into())
} else {
class_methods_parsing_result.push((
method_name.clone(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
args: arg_name_tys.into_iter().map(|(name, ty)| {
FuncArg {
name,
ty,
default_value: None
}
}).collect_vec(),
vars: method_type_var
}.into())),
*self.class_method_to_def_id.get(&Self::name_mangling(class_name.clone(), method_name)).unwrap()
))
} }
let all_tys_ok = { // put the fiedlds inside
let ret_ty_iter = vec![ret_ty]; if let Some(class_field_name_tys) = class_field_name_tys {
let ret_ty_iter = ret_ty_iter.iter(); assert!(class_fields_parsing_result.is_empty());
let mut all_tys = chain!(arg_tys.iter(), ret_ty_iter); class_fields_parsing_result.extend(class_field_name_tys);
all_tys.all(|x| {
let type_enum = unifier.get_ty(*x);
match type_enum.as_ref() {
TypeEnum::TObj { obj_id, .. } => {
!to_be_analyzed_class.contains(obj_id)
}
TypeEnum::TVirtual { ty } => {
if let TypeEnum::TObj { obj_id, .. } =
unifier.get_ty(*ty).as_ref()
{
!to_be_analyzed_class.contains(obj_id)
} else {
unreachable!()
}
}
TypeEnum::TVar { .. } => true,
_ => unreachable!(),
}
})
};
if all_tys_ok {
// TODO: put related value to the `class_methods_parsing_result`
unimplemented!()
} else {
to_be_analyzed_class.push(DefinitionId(class_ind));
// TODO: go to the next WHILE loop
unimplemented!()
} }
} else { } else {
// what should we do with `class A: a = 3`? // what should we do with `class A: a = 3`?
// do nothing, continue the for loop to iterate class ast
continue; continue;
} }
} };
// TODO: now it should be confirmed that every // now it should be confirmed that every
// methods and fields of the class can be correctly typed, put the results // methods and fields of the class can be correctly typed, put the results
// into the actual def_list and the unifier // into the actual class def method and fields field
} let (class_def, _) = &def_ast_list[class_ind];
Ok(()) let mut class_def = class_def.write();
} if let TopLevelDef::Class { fields, methods, .. } = class_def.deref_mut() {
for (ref n, ref t) in class_fields_parsing_result {
fields.push((n.clone(), *t));
}
for (n, t, id) in &class_methods_parsing_result {
methods.push((n.clone(), *t, *id));
}
} else { unreachable!() }
fn analyze_top_level_inheritance(&mut self) -> Result<(), String> { // change the signature field of the class methods
unimplemented!() for (_, ty, id) in &class_methods_parsing_result {
let (method_def, _) = &def_ast_list[id.0];
let mut method_def = method_def.write();
if let TopLevelDef::Function { signature, .. } = method_def.deref_mut() {
*signature = *ty;
}
}
};
Ok(())
} }
fn analyze_top_level_function(&mut self) -> Result<(), String> { fn analyze_top_level_function(&mut self) -> Result<(), String> {
@ -684,4 +696,27 @@ impl TopLevelComposer {
fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> { fn analyze_top_level_field_instantiation(&mut self) -> Result<(), String> {
unimplemented!() unimplemented!()
} }
fn check_ty_analyzed(ty: Type,
unifier: &mut Unifier,
to_be_analyzed: &[DefinitionId]) -> bool
{
let type_enum = unifier.get_ty(ty);
match type_enum.as_ref() {
TypeEnum::TObj { obj_id, .. } => {
!to_be_analyzed.contains(obj_id)
}
TypeEnum::TVirtual { ty } => {
if let TypeEnum::TObj { obj_id, .. } =
unifier.get_ty(*ty).as_ref()
{
!to_be_analyzed.contains(obj_id)
} else {
unreachable!()
}
}
TypeEnum::TVar { .. } => true,
_ => unreachable!(),
}
}
} }