forked from M-Labs/nac3
core: allow field initialization in function calls
This commit is contained in:
parent
44487b76ae
commit
4c504abd16
@ -23,7 +23,7 @@ impl Default for ComposerConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
||||||
pub struct TopLevelComposer {
|
pub struct TopLevelComposer {
|
||||||
// list of top level definitions, same as top level context
|
// list of top level definitions, same as top level context
|
||||||
pub definition_ast_list: Vec<DefAst>,
|
pub definition_ast_list: Vec<DefAst>,
|
||||||
@ -1723,7 +1723,13 @@ impl TopLevelComposer {
|
|||||||
if *name != init_str_id {
|
if *name != init_str_id {
|
||||||
unreachable!("must be init function here")
|
unreachable!("must be init function here")
|
||||||
}
|
}
|
||||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
// let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||||
|
let all_inited = Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
body.as_slice(),
|
||||||
|
)?;
|
||||||
|
|
||||||
for (f, _, _) in fields {
|
for (f, _, _) in fields {
|
||||||
if !all_inited.contains(f) {
|
if !all_inited.contains(f) {
|
||||||
return Err(HashSet::from([
|
return Err(HashSet::from([
|
||||||
|
@ -3,6 +3,7 @@ use std::convert::TryInto;
|
|||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
|
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
|
||||||
|
use ast::ExprKind;
|
||||||
use nac3parser::ast::{Constant, Location};
|
use nac3parser::ast::{Constant, Location};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use strum_macros::EnumIter;
|
use strum_macros::EnumIter;
|
||||||
@ -677,7 +678,11 @@ impl TopLevelComposer {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
|
pub fn get_all_assigned_field(
|
||||||
|
definition_ast_list: &Vec<DefAst>,
|
||||||
|
def: &Arc<RwLock<TopLevelDef>>,
|
||||||
|
stmts: &[Stmt<()>],
|
||||||
|
) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||||
let mut result = HashSet::new();
|
let mut result = HashSet::new();
|
||||||
for s in stmts {
|
for s in stmts {
|
||||||
match &s.node {
|
match &s.node {
|
||||||
@ -713,32 +718,151 @@ impl TopLevelComposer {
|
|||||||
// TODO: do not check for For and While?
|
// TODO: do not check for For and While?
|
||||||
ast::StmtKind::For { body, orelse, .. }
|
ast::StmtKind::For { body, orelse, .. }
|
||||||
| ast::StmtKind::While { body, orelse, .. } => {
|
| ast::StmtKind::While { body, orelse, .. } => {
|
||||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
result.extend(Self::get_all_assigned_field(
|
||||||
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
body.as_slice(),
|
||||||
|
)?);
|
||||||
|
result.extend(Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
orelse.as_slice(),
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
ast::StmtKind::If { body, orelse, .. } => {
|
ast::StmtKind::If { body, orelse, .. } => {
|
||||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
let inited_for_sure =
|
||||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||||
.copied()
|
.intersection(&Self::get_all_assigned_field(
|
||||||
.collect::<HashSet<_>>();
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
orelse.as_slice(),
|
||||||
|
)?)
|
||||||
|
.copied()
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
result.extend(inited_for_sure);
|
result.extend(inited_for_sure);
|
||||||
}
|
}
|
||||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
let inited_for_sure =
|
||||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||||
.copied()
|
.intersection(&Self::get_all_assigned_field(
|
||||||
.collect::<HashSet<_>>();
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
orelse.as_slice(),
|
||||||
|
)?)
|
||||||
|
.copied()
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
result.extend(inited_for_sure);
|
result.extend(inited_for_sure);
|
||||||
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
result.extend(Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
finalbody.as_slice(),
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
ast::StmtKind::With { body, .. } => {
|
ast::StmtKind::With { body, .. } => {
|
||||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
result.extend(Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
body.as_slice(),
|
||||||
|
)?);
|
||||||
}
|
}
|
||||||
ast::StmtKind::Pass { .. }
|
// If its a call to __init__function of ancestor extend with ancestor fields
|
||||||
| ast::StmtKind::Assert { .. }
|
ast::StmtKind::Expr { value, .. } => {
|
||||||
| ast::StmtKind::Expr { .. } => {}
|
// Check if Expression is a function call to self
|
||||||
|
if let ExprKind::Call { func, args, .. } = &value.node {
|
||||||
|
if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node {
|
||||||
|
let class_def = def.read();
|
||||||
|
let (ancestors, methods) = {
|
||||||
|
let mut class_methods: HashMap<StrRef, DefinitionId> =
|
||||||
|
HashMap::new();
|
||||||
|
let mut class_ancestors: HashMap<
|
||||||
|
StrRef,
|
||||||
|
HashMap<StrRef, DefinitionId>,
|
||||||
|
> = HashMap::new();
|
||||||
|
|
||||||
|
if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def {
|
||||||
|
for m in methods {
|
||||||
|
class_methods.insert(m.0, m.2);
|
||||||
|
}
|
||||||
|
ancestors.iter().skip(1).for_each(|a| {
|
||||||
|
if let TypeAnnotation::CustomClass { id, .. } = a {
|
||||||
|
let anc_def =
|
||||||
|
definition_ast_list.get(id.0).unwrap().0.read();
|
||||||
|
if let TopLevelDef::Class { name, methods, .. } =
|
||||||
|
&*anc_def
|
||||||
|
{
|
||||||
|
let mut temp: HashMap<StrRef, DefinitionId> =
|
||||||
|
HashMap::new();
|
||||||
|
for m in methods {
|
||||||
|
temp.insert(m.0, m.2);
|
||||||
|
}
|
||||||
|
// Remove module name suffix from name
|
||||||
|
let mut name_string = name.to_string();
|
||||||
|
let split_loc =
|
||||||
|
name_string.find(|c| c == '.').unwrap() + 1;
|
||||||
|
class_ancestors.insert(
|
||||||
|
name_string.split_off(split_loc).into(),
|
||||||
|
temp,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(class_ancestors, class_methods)
|
||||||
|
};
|
||||||
|
if let ExprKind::Name { id, .. } = value.node {
|
||||||
|
if id == "self".into() {
|
||||||
|
// Get Class methods and fields
|
||||||
|
let method_id = methods.get(fn_name);
|
||||||
|
if method_id.is_some() {
|
||||||
|
if let Some(fn_ast) = &definition_ast_list
|
||||||
|
.get(method_id.unwrap().0)
|
||||||
|
.unwrap()
|
||||||
|
.1
|
||||||
|
{
|
||||||
|
if let ast::StmtKind::FunctionDef { body, .. } =
|
||||||
|
&fn_ast.node
|
||||||
|
{
|
||||||
|
result.extend(Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
body.as_slice(),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Some(ancestor_methods) = ancestors.get(&id) {
|
||||||
|
// First arg must be `self` when calling ancestor function
|
||||||
|
if let ExprKind::Name { id, .. } = args[0].node {
|
||||||
|
if id == "self".into() {
|
||||||
|
if let Some(method_id) = ancestor_methods.get(fn_name) {
|
||||||
|
if let Some(fn_ast) =
|
||||||
|
&definition_ast_list.get(method_id.0).unwrap().1
|
||||||
|
{
|
||||||
|
if let ast::StmtKind::FunctionDef {
|
||||||
|
body, ..
|
||||||
|
} = &fn_ast.node
|
||||||
|
{
|
||||||
|
result.extend(
|
||||||
|
Self::get_all_assigned_field(
|
||||||
|
definition_ast_list,
|
||||||
|
def,
|
||||||
|
body.as_slice(),
|
||||||
|
)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
|
println!("{:?}", s.node);
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user