Compare commits

...

2 Commits

Author SHA1 Message Date
9421f166de WIP 2024-07-25 17:42:37 +08:00
4c504abd16 core: allow field initialization in function calls 2024-07-25 17:37:02 +08:00
5 changed files with 177 additions and 22 deletions

View File

@ -161,7 +161,9 @@
clippy
pre-commit
rustfmt
rust-analyzer
];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -23,7 +23,7 @@ impl Default for ComposerConfig {
}
}
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub struct TopLevelComposer {
// list of top level definitions, same as top level context
pub definition_ast_list: Vec<DefAst>,
@ -1723,7 +1723,13 @@ impl TopLevelComposer {
if *name != init_str_id {
unreachable!("must be init function here")
}
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
// let all_inited = Self::get_all_assigned_field(body.as_slice())?;
let all_inited = Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?;
for (f, _, _) in fields {
if !all_inited.contains(f) {
return Err(HashSet::from([

View File

@ -3,6 +3,7 @@ use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use ast::ExprKind;
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
@ -677,7 +678,11 @@ impl TopLevelComposer {
)
}
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
pub fn get_all_assigned_field(
definition_ast_list: &Vec<DefAst>,
def: &Arc<RwLock<TopLevelDef>>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new();
for s in stmts {
match &s.node {
@ -713,32 +718,151 @@ impl TopLevelComposer {
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
orelse.as_slice(),
)?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure =
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
.intersection(&Self::get_all_assigned_field(
definition_ast_list,
def,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure =
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
.intersection(&Self::get_all_assigned_field(
definition_ast_list,
def,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
finalbody.as_slice(),
)?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?);
}
ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. }
| ast::StmtKind::Expr { .. } => {}
// If its a call to __init__function of ancestor extend with ancestor fields
ast::StmtKind::Expr { value, .. } => {
// Check if Expression is a function call to self
if let ExprKind::Call { func, args, .. } = &value.node {
if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node {
let class_def = def.read();
let (ancestors, methods) = {
let mut class_methods: HashMap<StrRef, DefinitionId> =
HashMap::new();
let mut class_ancestors: HashMap<
StrRef,
HashMap<StrRef, DefinitionId>,
> = HashMap::new();
if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def {
for m in methods {
class_methods.insert(m.0, m.2);
}
ancestors.iter().skip(1).for_each(|a| {
if let TypeAnnotation::CustomClass { id, .. } = a {
let anc_def =
definition_ast_list.get(id.0).unwrap().0.read();
if let TopLevelDef::Class { name, methods, .. } =
&*anc_def
{
let mut temp: HashMap<StrRef, DefinitionId> =
HashMap::new();
for m in methods {
temp.insert(m.0, m.2);
}
// Remove module name suffix from name
let mut name_string = name.to_string();
let split_loc =
name_string.find(|c| c == '.').unwrap() + 1;
class_ancestors.insert(
name_string.split_off(split_loc).into(),
temp,
);
}
}
});
}
(class_ancestors, class_methods)
};
if let ExprKind::Name { id, .. } = value.node {
if id == "self".into() {
// Get Class methods and fields
let method_id = methods.get(fn_name);
if method_id.is_some() {
if let Some(fn_ast) = &definition_ast_list
.get(method_id.unwrap().0)
.unwrap()
.1
{
if let ast::StmtKind::FunctionDef { body, .. } =
&fn_ast.node
{
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?);
}
}
}
} else if let Some(ancestor_methods) = ancestors.get(&id) {
// First arg must be `self` when calling ancestor function
if let ExprKind::Name { id, .. } = args[0].node {
if id == "self".into() {
if let Some(method_id) = ancestor_methods.get(fn_name) {
if let Some(fn_ast) =
&definition_ast_list.get(method_id.0).unwrap().1
{
if let ast::StmtKind::FunctionDef {
body, ..
} = &fn_ast.node
{
result.extend(
Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?,
);
}
}
};
}
}
}
}
}
}
}
ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {}
_ => {
println!("{:?}", s.node);
unimplemented!()
}
}

View File

@ -23,10 +23,33 @@ class B(A):
self.a = b + 1
self.b = b
class C:
a: int32
def __init__(self):
self.a = 42
def test2(self):
self.a = 23
class D(C):
def __init__(self):
C.test2(self)
# self.test()
# C.test2(self)
# self.a = 2
# __main__.C.__init__(self)
def test(self):
self.a = 2
def run() -> int32:
aaa = A(5)
bbb = B(2)
aaa.f1()
bbb.f1()
x = D()
x.__init__()
output_int32(x.a)
# aaa = A(5)
# bbb = B(2)
# aaa.f1()
# bbb.f1()
return 0

BIN
pyo3_output/nac3artiq.so Executable file

Binary file not shown.