nac3core: better field initialization check

This commit is contained in:
ychenfo 2021-09-21 02:48:42 +08:00
parent e66693282c
commit 20905a9b67
3 changed files with 119 additions and 11 deletions

View File

@ -1169,20 +1169,20 @@ impl TopLevelComposer {
methods, methods,
fields, fields,
type_vars, type_vars,
name, name: class_name,
object_id, object_id,
resolver: _, resolver: _,
.. ..
} = &*class_def } = &*class_def
{ {
let mut has_init = false; let mut init_id: Option<DefinitionId> = None;
// get the class contructor type correct // get the class contructor type correct
let (contor_args, contor_type_vars) = { let (contor_args, contor_type_vars) = {
let mut constructor_args: Vec<FuncArg> = Vec::new(); let mut constructor_args: Vec<FuncArg> = Vec::new();
let mut type_vars: HashMap<u32, Type> = HashMap::new(); let mut type_vars: HashMap<u32, Type> = HashMap::new();
for (name, func_sig, ..) in methods { for (name, func_sig, id) in methods {
if name == "__init__" { if name == "__init__" {
has_init = true; init_id = Some(*id);
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
let FunSignature { args, vars, .. } = &*sig.borrow(); let FunSignature { args, vars, .. } = &*sig.borrow();
constructor_args.extend_from_slice(args); constructor_args.extend_from_slice(args);
@ -1207,11 +1207,22 @@ impl TopLevelComposer {
self.unifier.unify(constructor.unwrap(), contor_type)?; self.unifier.unify(constructor.unwrap(), contor_type)?;
// class field instantiation check // class field instantiation check
// TODO: this is a really simple one, more check later if let (Some(init_id), false) = (init_id, fields.is_empty()) {
if !has_init && !fields.is_empty() { let init_ast =
return Err(format!("fields of class {} not fully initialized", name)) self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
if name != "__init__" {
unreachable!("must be init function here")
}
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
return Err(format!(
"fields of class {} not fully initialized",
class_name
));
}
}
} }
} }
} }

View File

@ -268,4 +268,73 @@ impl TopLevelComposer {
unifier, unifier,
) )
} }
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<String>, String> {
let mut result: HashSet<String> = HashSet::new();
for s in stmts {
match &s.node {
ast::StmtKind::AnnAssign { target, .. }
if {
if let ast::ExprKind::Attribute { value, .. } = &target.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
id == "self"
} else {
false
}
} else {
false
}
} =>
{
return Err(format!(
"redundant type annotation for class fields at {}",
s.location
))
}
ast::StmtKind::Assign { targets, .. } => {
for t in targets {
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "self" {
result.insert(attr.clone());
}
}
}
}
}
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned()
.collect::<HashSet<String>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned()
.collect::<HashSet<String>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
}
ast::StmtKind::Pass => {}
ast::StmtKind::Assert { .. } => {}
ast::StmtKind::Expr { .. } => {}
_ => {
unimplemented!()
}
}
}
Ok(result)
}
} }

View File

@ -870,7 +870,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return SELF return SELF
def sum(self) -> int32: def sum(self) -> int32:
if self.a == 0: if self.a == 0:
return self.a + self return self.a
else: else:
a = self.a a = self.a
self.a = self.a - 1 self.a = self.a - 1
@ -898,7 +898,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return ret * ret return ret * ret
"}, "},
indoc! {" indoc! {"
def sum3(l: list[V]) -> V: def sum_three(l: list[V]) -> V:
return l[0] + l[1] + l[2] return l[0] + l[1] + l[2]
"}, "},
indoc! {" indoc! {"
@ -921,7 +921,11 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
b: bool b: bool
def __init__(self, aa: G): def __init__(self, aa: G):
self.a = aa self.a = aa
if 2 > 1:
self.b = True self.b = True
else:
# self.b = False
pass
def fun(self, a: G) -> list[G]: def fun(self, a: G) -> list[G]:
ret = [a, self.a] ret = [a, self.a]
return ret if self.b else self.fun(self.a) return ret if self.b else self.fun(self.a)
@ -930,6 +934,30 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
vec![]; vec![];
"type var class" "type var class"
)] )]
#[test_case(
vec![
indoc! {"
class A:
def fun(self):
1 + 2
"},
indoc!{"
class B:
a: int32
b: bool
def __init__(self):
# self.b = False
if 3 > 2:
self.a = 3
self.b = False
else:
self.a = 4
self.b = True
"}
],
vec![];
"no_init_inst_check"
)]
fn test_inference(source: Vec<&str>, res: Vec<&str>) { fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true; let print = true;
let mut composer: TopLevelComposer = Default::default(); let mut composer: TopLevelComposer = Default::default();