forked from M-Labs/nac3
nac3core: better field initialization check
This commit is contained in:
parent
e66693282c
commit
20905a9b67
@ -1169,20 +1169,20 @@ impl TopLevelComposer {
|
|||||||
methods,
|
methods,
|
||||||
fields,
|
fields,
|
||||||
type_vars,
|
type_vars,
|
||||||
name,
|
name: class_name,
|
||||||
object_id,
|
object_id,
|
||||||
resolver: _,
|
resolver: _,
|
||||||
..
|
..
|
||||||
} = &*class_def
|
} = &*class_def
|
||||||
{
|
{
|
||||||
let mut has_init = false;
|
let mut init_id: Option<DefinitionId> = None;
|
||||||
// get the class contructor type correct
|
// get the class contructor type correct
|
||||||
let (contor_args, contor_type_vars) = {
|
let (contor_args, contor_type_vars) = {
|
||||||
let mut constructor_args: Vec<FuncArg> = Vec::new();
|
let mut constructor_args: Vec<FuncArg> = Vec::new();
|
||||||
let mut type_vars: HashMap<u32, Type> = HashMap::new();
|
let mut type_vars: HashMap<u32, Type> = HashMap::new();
|
||||||
for (name, func_sig, ..) in methods {
|
for (name, func_sig, id) in methods {
|
||||||
if name == "__init__" {
|
if name == "__init__" {
|
||||||
has_init = true;
|
init_id = Some(*id);
|
||||||
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
|
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
|
||||||
let FunSignature { args, vars, .. } = &*sig.borrow();
|
let FunSignature { args, vars, .. } = &*sig.borrow();
|
||||||
constructor_args.extend_from_slice(args);
|
constructor_args.extend_from_slice(args);
|
||||||
@ -1207,11 +1207,22 @@ impl TopLevelComposer {
|
|||||||
self.unifier.unify(constructor.unwrap(), contor_type)?;
|
self.unifier.unify(constructor.unwrap(), contor_type)?;
|
||||||
|
|
||||||
// class field instantiation check
|
// class field instantiation check
|
||||||
// TODO: this is a really simple one, more check later
|
if let (Some(init_id), false) = (init_id, fields.is_empty()) {
|
||||||
if !has_init && !fields.is_empty() {
|
let init_ast =
|
||||||
return Err(format!("fields of class {} not fully initialized", name))
|
self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap();
|
||||||
|
if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node {
|
||||||
|
if name != "__init__" {
|
||||||
|
unreachable!("must be init function here")
|
||||||
|
}
|
||||||
|
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||||
|
if fields.iter().any(|(x, _)| !all_inited.contains(x)) {
|
||||||
|
return Err(format!(
|
||||||
|
"fields of class {} not fully initialized",
|
||||||
|
class_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,4 +268,73 @@ impl TopLevelComposer {
|
|||||||
unifier,
|
unifier,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result<HashSet<String>, String> {
|
||||||
|
let mut result: HashSet<String> = HashSet::new();
|
||||||
|
for s in stmts {
|
||||||
|
match &s.node {
|
||||||
|
ast::StmtKind::AnnAssign { target, .. }
|
||||||
|
if {
|
||||||
|
if let ast::ExprKind::Attribute { value, .. } = &target.node {
|
||||||
|
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||||
|
id == "self"
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} =>
|
||||||
|
{
|
||||||
|
return Err(format!(
|
||||||
|
"redundant type annotation for class fields at {}",
|
||||||
|
s.location
|
||||||
|
))
|
||||||
|
}
|
||||||
|
ast::StmtKind::Assign { targets, .. } => {
|
||||||
|
for t in targets {
|
||||||
|
if let ast::ExprKind::Attribute { value, attr, .. } = &t.node {
|
||||||
|
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||||
|
if id == "self" {
|
||||||
|
result.insert(attr.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: do not check for For and While?
|
||||||
|
ast::StmtKind::For { body, orelse, .. }
|
||||||
|
| ast::StmtKind::While { body, orelse, .. } => {
|
||||||
|
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||||
|
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
|
||||||
|
}
|
||||||
|
ast::StmtKind::If { body, orelse, .. } => {
|
||||||
|
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||||
|
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashSet<String>>();
|
||||||
|
result.extend(inited_for_sure);
|
||||||
|
}
|
||||||
|
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||||
|
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||||
|
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||||
|
.cloned()
|
||||||
|
.collect::<HashSet<String>>();
|
||||||
|
result.extend(inited_for_sure);
|
||||||
|
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
||||||
|
}
|
||||||
|
ast::StmtKind::With { body, .. } => {
|
||||||
|
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||||
|
}
|
||||||
|
ast::StmtKind::Pass => {}
|
||||||
|
ast::StmtKind::Assert { .. } => {}
|
||||||
|
ast::StmtKind::Expr { .. } => {}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -870,7 +870,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
return SELF
|
return SELF
|
||||||
def sum(self) -> int32:
|
def sum(self) -> int32:
|
||||||
if self.a == 0:
|
if self.a == 0:
|
||||||
return self.a + self
|
return self.a
|
||||||
else:
|
else:
|
||||||
a = self.a
|
a = self.a
|
||||||
self.a = self.a - 1
|
self.a = self.a - 1
|
||||||
@ -898,7 +898,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
return ret * ret
|
return ret * ret
|
||||||
"},
|
"},
|
||||||
indoc! {"
|
indoc! {"
|
||||||
def sum3(l: list[V]) -> V:
|
def sum_three(l: list[V]) -> V:
|
||||||
return l[0] + l[1] + l[2]
|
return l[0] + l[1] + l[2]
|
||||||
"},
|
"},
|
||||||
indoc! {"
|
indoc! {"
|
||||||
@ -921,7 +921,11 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
b: bool
|
b: bool
|
||||||
def __init__(self, aa: G):
|
def __init__(self, aa: G):
|
||||||
self.a = aa
|
self.a = aa
|
||||||
|
if 2 > 1:
|
||||||
self.b = True
|
self.b = True
|
||||||
|
else:
|
||||||
|
# self.b = False
|
||||||
|
pass
|
||||||
def fun(self, a: G) -> list[G]:
|
def fun(self, a: G) -> list[G]:
|
||||||
ret = [a, self.a]
|
ret = [a, self.a]
|
||||||
return ret if self.b else self.fun(self.a)
|
return ret if self.b else self.fun(self.a)
|
||||||
@ -930,6 +934,30 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
vec![];
|
vec![];
|
||||||
"type var class"
|
"type var class"
|
||||||
)]
|
)]
|
||||||
|
#[test_case(
|
||||||
|
vec![
|
||||||
|
indoc! {"
|
||||||
|
class A:
|
||||||
|
def fun(self):
|
||||||
|
1 + 2
|
||||||
|
"},
|
||||||
|
indoc!{"
|
||||||
|
class B:
|
||||||
|
a: int32
|
||||||
|
b: bool
|
||||||
|
def __init__(self):
|
||||||
|
# self.b = False
|
||||||
|
if 3 > 2:
|
||||||
|
self.a = 3
|
||||||
|
self.b = False
|
||||||
|
else:
|
||||||
|
self.a = 4
|
||||||
|
self.b = True
|
||||||
|
"}
|
||||||
|
],
|
||||||
|
vec![];
|
||||||
|
"no_init_inst_check"
|
||||||
|
)]
|
||||||
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let print = true;
|
let print = true;
|
||||||
let mut composer: TopLevelComposer = Default::default();
|
let mut composer: TopLevelComposer = Default::default();
|
||||||
|
Loading…
Reference in New Issue
Block a user