forked from M-Labs/nac3
nac3core: better field initialization check
This commit is contained in:
parent
e66693282c
commit
20905a9b67
@ -1169,20 +1169,20 @@ impl TopLevelComposer {
|
||||
methods,
|
||||
fields,
|
||||
type_vars,
|
||||
name,
|
||||
name: class_name,
|
||||
object_id,
|
||||
resolver: _,
|
||||
..
|
||||
} = &*class_def
|
||||
{
|
||||
let mut has_init = false;
|
||||
let mut init_id: Option<DefinitionId> = None;
|
||||
// get the class contructor type correct
|
||||
let (contor_args, contor_type_vars) = {
|
||||
let mut constructor_args: Vec<FuncArg> = Vec::new();
|
||||
let mut type_vars: HashMap<u32, Type> = HashMap::new();
|
||||
for (name, func_sig, ..) in methods {
|
||||
for (name, func_sig, id) in methods {
|
||||
if name == "__init__" {
|
||||
has_init = true;
|
||||
init_id = Some(*id);
|
||||
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
|
||||
let FunSignature { args, vars, .. } = &*sig.borrow();
|
||||
constructor_args.extend_from_slice(args);
|
||||
@ -1207,11 +1207,22 @@ impl TopLevelComposer {
|
||||
self.unifier.unify(constructor.unwrap(), contor_type)?;
|
||||
|
||||
// class field instantiation check
|
||||
// TODO: this is a really simple one, more check later
|
||||
if !has_init && !fields.is_empty() {
|
||||
return Err(format!("fields of class {} not fully initialized", name))
|
||||
if let (Some(init_id), false) = (init_id, fields.is_empty()) {
|
||||
let init_ast =
|
||||
self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
|
||||
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
|
||||
if name != "__init__" {
|
||||
unreachable!("must be init function here")
|
||||
}
|
||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
|
||||
return Err(format!(
|
||||
"fields of class {} not fully initialized",
|
||||
class_name
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,4 +268,73 @@ impl TopLevelComposer {
|
||||
unifier,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<String>, String> {
|
||||
let mut result: HashSet<String> = HashSet::new();
|
||||
for s in stmts {
|
||||
match &s.node {
|
||||
ast::StmtKind::AnnAssign { target, .. }
|
||||
if {
|
||||
if let ast::ExprKind::Attribute { value, .. } = &target.node {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
id == "self"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} =>
|
||||
{
|
||||
return Err(format!(
|
||||
"redundant type annotation for class fields at {}",
|
||||
s.location
|
||||
))
|
||||
}
|
||||
ast::StmtKind::Assign { targets, .. } => {
|
||||
for t in targets {
|
||||
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
if id == "self" {
|
||||
result.insert(attr.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: do not check for For and While?
|
||||
ast::StmtKind::For { body, orelse, .. }
|
||||
| ast::StmtKind::While { body, orelse, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
|
||||
}
|
||||
ast::StmtKind::If { body, orelse, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.cloned()
|
||||
.collect::<HashSet<String>>();
|
||||
result.extend(inited_for_sure);
|
||||
}
|
||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.cloned()
|
||||
.collect::<HashSet<String>>();
|
||||
result.extend(inited_for_sure);
|
||||
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
||||
}
|
||||
ast::StmtKind::With { body, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
}
|
||||
ast::StmtKind::Pass => {}
|
||||
ast::StmtKind::Assert { .. } => {}
|
||||
ast::StmtKind::Expr { .. } => {}
|
||||
|
||||
_ => {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
@ -870,7 +870,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||
return SELF
|
||||
def sum(self) -> int32:
|
||||
if self.a == 0:
|
||||
return self.a + self
|
||||
return self.a
|
||||
else:
|
||||
a = self.a
|
||||
self.a = self.a - 1
|
||||
@ -898,7 +898,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||
return ret * ret
|
||||
"},
|
||||
indoc! {"
|
||||
def sum3(l: list[V]) -> V:
|
||||
def sum_three(l: list[V]) -> V:
|
||||
return l[0] + l[1] + l[2]
|
||||
"},
|
||||
indoc! {"
|
||||
@ -921,7 +921,11 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||
b: bool
|
||||
def __init__(self, aa: G):
|
||||
self.a = aa
|
||||
self.b = True
|
||||
if 2 > 1:
|
||||
self.b = True
|
||||
else:
|
||||
# self.b = False
|
||||
pass
|
||||
def fun(self, a: G) -> list[G]:
|
||||
ret = [a, self.a]
|
||||
return ret if self.b else self.fun(self.a)
|
||||
@ -930,6 +934,30 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||
vec![];
|
||||
"type var class"
|
||||
)]
|
||||
#[test_case(
|
||||
vec![
|
||||
indoc! {"
|
||||
class A:
|
||||
def fun(self):
|
||||
1 + 2
|
||||
"},
|
||||
indoc!{"
|
||||
class B:
|
||||
a: int32
|
||||
b: bool
|
||||
def __init__(self):
|
||||
# self.b = False
|
||||
if 3 > 2:
|
||||
self.a = 3
|
||||
self.b = False
|
||||
else:
|
||||
self.a = 4
|
||||
self.b = True
|
||||
"}
|
||||
],
|
||||
vec![];
|
||||
"no_init_inst_check"
|
||||
)]
|
||||
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||
let print = true;
|
||||
let mut composer: TopLevelComposer = Default::default();
|
||||
|
Loading…
Reference in New Issue
Block a user