forked from M-Labs/nac3
core: allow field initialization in function calls
This commit is contained in:
parent
44487b76ae
commit
4c504abd16
@ -23,7 +23,7 @@ impl Default for ComposerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
||||
pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
||||
pub struct TopLevelComposer {
|
||||
// list of top level definitions, same as top level context
|
||||
pub definition_ast_list: Vec<DefAst>,
|
||||
@ -1723,7 +1723,13 @@ impl TopLevelComposer {
|
||||
if *name != init_str_id {
|
||||
unreachable!("must be init function here")
|
||||
}
|
||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
// let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
let all_inited = Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?;
|
||||
|
||||
for (f, _, _) in fields {
|
||||
if !all_inited.contains(f) {
|
||||
return Err(HashSet::from([
|
||||
|
@ -3,6 +3,7 @@ use std::convert::TryInto;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
|
||||
use ast::ExprKind;
|
||||
use nac3parser::ast::{Constant, Location};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
@ -677,7 +678,11 @@ impl TopLevelComposer {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
pub fn get_all_assigned_field(
|
||||
definition_ast_list: &Vec<DefAst>,
|
||||
def: &Arc<RwLock<TopLevelDef>>,
|
||||
stmts: &[Stmt<()>],
|
||||
) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
let mut result = HashSet::new();
|
||||
for s in stmts {
|
||||
match &s.node {
|
||||
@ -713,32 +718,151 @@ impl TopLevelComposer {
|
||||
// TODO: do not check for For and While?
|
||||
ast::StmtKind::For { body, orelse, .. }
|
||||
| ast::StmtKind::While { body, orelse, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::If { body, orelse, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
}
|
||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
finalbody.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::With { body, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::Pass { .. }
|
||||
| ast::StmtKind::Assert { .. }
|
||||
| ast::StmtKind::Expr { .. } => {}
|
||||
// If its a call to __init__function of ancestor extend with ancestor fields
|
||||
ast::StmtKind::Expr { value, .. } => {
|
||||
// Check if Expression is a function call to self
|
||||
if let ExprKind::Call { func, args, .. } = &value.node {
|
||||
if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node {
|
||||
let class_def = def.read();
|
||||
let (ancestors, methods) = {
|
||||
let mut class_methods: HashMap<StrRef, DefinitionId> =
|
||||
HashMap::new();
|
||||
let mut class_ancestors: HashMap<
|
||||
StrRef,
|
||||
HashMap<StrRef, DefinitionId>,
|
||||
> = HashMap::new();
|
||||
|
||||
if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def {
|
||||
for m in methods {
|
||||
class_methods.insert(m.0, m.2);
|
||||
}
|
||||
ancestors.iter().skip(1).for_each(|a| {
|
||||
if let TypeAnnotation::CustomClass { id, .. } = a {
|
||||
let anc_def =
|
||||
definition_ast_list.get(id.0).unwrap().0.read();
|
||||
if let TopLevelDef::Class { name, methods, .. } =
|
||||
&*anc_def
|
||||
{
|
||||
let mut temp: HashMap<StrRef, DefinitionId> =
|
||||
HashMap::new();
|
||||
for m in methods {
|
||||
temp.insert(m.0, m.2);
|
||||
}
|
||||
// Remove module name suffix from name
|
||||
let mut name_string = name.to_string();
|
||||
let split_loc =
|
||||
name_string.find(|c| c == '.').unwrap() + 1;
|
||||
class_ancestors.insert(
|
||||
name_string.split_off(split_loc).into(),
|
||||
temp,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
(class_ancestors, class_methods)
|
||||
};
|
||||
if let ExprKind::Name { id, .. } = value.node {
|
||||
if id == "self".into() {
|
||||
// Get Class methods and fields
|
||||
let method_id = methods.get(fn_name);
|
||||
if method_id.is_some() {
|
||||
if let Some(fn_ast) = &definition_ast_list
|
||||
.get(method_id.unwrap().0)
|
||||
.unwrap()
|
||||
.1
|
||||
{
|
||||
if let ast::StmtKind::FunctionDef { body, .. } =
|
||||
&fn_ast.node
|
||||
{
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(ancestor_methods) = ancestors.get(&id) {
|
||||
// First arg must be `self` when calling ancestor function
|
||||
if let ExprKind::Name { id, .. } = args[0].node {
|
||||
if id == "self".into() {
|
||||
if let Some(method_id) = ancestor_methods.get(fn_name) {
|
||||
if let Some(fn_ast) =
|
||||
&definition_ast_list.get(method_id.0).unwrap().1
|
||||
{
|
||||
if let ast::StmtKind::FunctionDef {
|
||||
body, ..
|
||||
} = &fn_ast.node
|
||||
{
|
||||
result.extend(
|
||||
Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {}
|
||||
|
||||
_ => {
|
||||
println!("{:?}", s.node);
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user