nac3core: better field initialization check

This commit is contained in:
ychenfo 2021-09-21 02:48:42 +08:00
parent e66693282c
commit 20905a9b67
3 changed files with 119 additions and 11 deletions

View File

@ -1169,20 +1169,20 @@ impl TopLevelComposer {
methods,
fields,
type_vars,
name,
name: class_name,
object_id,
resolver: _,
..
} = &*class_def
{
let mut has_init = false;
let mut init_id: Option<DefinitionId> = None;
// get the class contructor type correct
let (contor_args, contor_type_vars) = {
let mut constructor_args: Vec<FuncArg> = Vec::new();
let mut type_vars: HashMap<u32, Type> = HashMap::new();
for (name, func_sig, ..) in methods {
for (name, func_sig, id) in methods {
if name == "__init__" {
has_init = true;
init_id = Some(*id);
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
let FunSignature { args, vars, .. } = &*sig.borrow();
constructor_args.extend_from_slice(args);
@ -1207,11 +1207,22 @@ impl TopLevelComposer {
self.unifier.unify(constructor.unwrap(), contor_type)?;
// class field instantiation check
// TODO: this is a really simple one, more check later
if !has_init && !fields.is_empty() {
return Err(format!("fields of class {} not fully initialized", name))
if let (Some(init_id), false) = (init_id, fields.is_empty()) {
let init_ast =
self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
if name != "__init__" {
unreachable!("must be init function here")
}
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
return Err(format!(
"fields of class {} not fully initialized",
class_name
));
}
}
}
}
}

View File

@ -268,4 +268,73 @@ impl TopLevelComposer {
unifier,
)
}
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<String>, String> {
let mut result: HashSet<String> = HashSet::new();
for s in stmts {
match &s.node {
ast::StmtKind::AnnAssign { target, .. }
if {
if let ast::ExprKind::Attribute { value, .. } = &target.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
id == "self"
} else {
false
}
} else {
false
}
} =>
{
return Err(format!(
"redundant type annotation for class fields at {}",
s.location
))
}
ast::StmtKind::Assign { targets, .. } => {
for t in targets {
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
if let ast::ExprKind::Name { id, .. } = &value.node {
if id == "self" {
result.insert(attr.clone());
}
}
}
}
}
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned()
.collect::<HashSet<String>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.cloned()
.collect::<HashSet<String>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
}
ast::StmtKind::Pass => {}
ast::StmtKind::Assert { .. } => {}
ast::StmtKind::Expr { .. } => {}
_ => {
unimplemented!()
}
}
}
Ok(result)
}
}

View File

@ -870,7 +870,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return SELF
def sum(self) -> int32:
if self.a == 0:
return self.a + self
return self.a
else:
a = self.a
self.a = self.a - 1
@ -898,7 +898,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return ret * ret
"},
indoc! {"
def sum3(l: list[V]) -> V:
def sum_three(l: list[V]) -> V:
return l[0] + l[1] + l[2]
"},
indoc! {"
@ -921,7 +921,11 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
b: bool
def __init__(self, aa: G):
self.a = aa
self.b = True
if 2 > 1:
self.b = True
else:
# self.b = False
pass
def fun(self, a: G) -> list[G]:
ret = [a, self.a]
return ret if self.b else self.fun(self.a)
@ -930,6 +934,30 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
vec![];
"type var class"
)]
#[test_case(
vec![
indoc! {"
class A:
def fun(self):
1 + 2
"},
indoc!{"
class B:
a: int32
b: bool
def __init__(self):
# self.b = False
if 3 > 2:
self.a = 3
self.b = False
else:
self.a = 4
self.b = True
"}
],
vec![];
"no_init_inst_check"
)]
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true;
let mut composer: TopLevelComposer = Default::default();