diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 4417a682..2ede61a6 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -280,6 +280,7 @@ impl Nac3 { self.builtins.clone(), ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ); + composer.build_constructor_lookup(self.top_levels.iter().map(|(stmt, _, _)| stmt)); let builtins = PyModule::import(py, "builtins")?; let typings = PyModule::import(py, "typing")?; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 16f508a7..c11855b9 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -37,6 +37,8 @@ pub struct TopLevelComposer { // number of built-in function and classes in the definition list, later skip pub builtin_num: usize, pub core_config: ComposerConfig, + // the class name and its constructor function + pub constructor_lookup: HashMap>, } impl Default for TopLevelComposer { @@ -132,12 +134,75 @@ impl TopLevelComposer { defined_names, method_class, core_config, + constructor_lookup: Default::default(), }, builtin_id, builtin_ty, ) } + pub fn build_constructor_lookup<'a, I>(&mut self, stmts: I) + where + I: Iterator> + { + let classes = Vec::from_iter(stmts.filter_map(|stmt| { + if let ast::StmtKind::ClassDef { name, bases, body, .. } = &stmt.node { + Some((name, bases, body)) + } else { + None + } + })); + + let base_class_lookup: HashMap = HashMap::from_iter( + classes.iter().filter_map(|(class_name, bases, _)| { + // Get the first base class in the Vector of bases is good enough, since we only support single inheritance + bases + .get(0) + .and_then(|ast::Located { node, .. }| { + if let ast::ExprKind::Name { id, .. } = node { + Some(*id) + } else { + None + } + }) + .and_then(|base_class_name| Some((**class_name, base_class_name))) + }) + ); + + let constructor_lookup: HashMap> = HashMap::from_iter( + classes.iter().filter_map(|(class_name, _, body)| { + body.iter().find_map(|stmt| { + if let ast::StmtKind::FunctionDef { name, .. } = &stmt.node { + if name == &"__init__".into() { + return Some((**class_name, stmt.clone())); + } + } + None + }) + }) + ); + + for (class, _, _) in classes + .iter() + .filter(|(c, _, _)| constructor_lookup.get(c).is_none()) + { + let mut current_class = class.clone(); + while let Some(base) = base_class_lookup.get(current_class) { + if let Some(cons) = constructor_lookup.get(base) { + self.constructor_lookup.insert(**class, cons.clone()); + break; + } else { + current_class = base; + } + } + } + + // copy the rest of the classes with constructor into the self.constructor_lookup + for (k, v) in constructor_lookup.into_iter() { + self.constructor_lookup.insert(k, v); + } + } + pub fn make_top_level_context(&self) -> TopLevelContext { TopLevelContext { definitions: RwLock::new( @@ -222,18 +287,19 @@ impl TopLevelComposer { // we do not push anything to the def list, so we keep track of the index // and then push in the correct order after the for loop let mut class_method_index_offset = 0; - let init_id = "__init__".into(); let exception_id = "Exception".into(); // TODO: Fix this hack. We will generate constructor for classes that inherit // from Exception class (directly or indirectly), but this code cannot handle // subclass of other exception classes. let mut contains_constructor = bases .iter().any(|base| matches!(base.node, ast::ExprKind::Name { id, .. } if id == exception_id)); + + if self.constructor_lookup.contains_key(&class_name) { + contains_constructor = true; + } + for b in body { if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node { - if method_name == &init_id { - contains_constructor = true; - } if self.keyword_list.contains(method_name) { return Err(format!( "cannot use keyword `{}` as a method name (at {})", diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 49400711..1fa35a30 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -566,6 +566,152 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } +#[test_case( + indoc! {" + class A: + def __init__(self): + pass + + class B(A): + pass + + class C(B): + pass + "}, + HashMap::from([ + ("A".into(),ast::Located { + location: ast::Location::new(2, 5, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::FunctionDef { + name: "__init__".into(), + args: Box::new(ast::Arguments { + posonlyargs: vec![], + args: vec![ + ast::Located { + location: ast::Location::new(2, 18, ast::FileName("unknown".into())), + custom: (), + node: ast::ArgData { + arg: "self".into(), + annotation: None, + type_comment: None, + }, + }, + ], + vararg: None, + kwonlyargs: vec![], + kw_defaults: vec![], + kwarg: None, + defaults: vec![], + }), + body: vec![ + ast::Located { + location: Location::new(3, 9, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::Pass { + config_comment: vec![], + }, + }, + ], + decorator_list: vec![], + returns: None, + type_comment: None, + config_comment: vec![], + }, + }), + ("B".into(), ast::Located { + location: ast::Location::new(2, 5, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::FunctionDef { + name: "__init__".into(), + args: Box::new(ast::Arguments { + posonlyargs: vec![], + args: vec![ + ast::Located { + location: ast::Location::new(2, 18, ast::FileName("unknown".into())), + custom: (), + node: ast::ArgData { + arg: "self".into(), + annotation: None, + type_comment: None, + }, + }, + ], + vararg: None, + kwonlyargs: vec![], + kw_defaults: vec![], + kwarg: None, + defaults: vec![], + }), + body: vec![ + ast::Located { + location: Location::new(3, 9, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::Pass { + config_comment: vec![], + }, + }, + ], + decorator_list: vec![], + returns: None, + type_comment: None, + config_comment: vec![], + }, + }), + ("C".into(), ast::Located { + location: ast::Location::new(2, 5, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::FunctionDef { + name: "__init__".into(), + args: Box::new(ast::Arguments { + posonlyargs: vec![], + args: vec![ + ast::Located { + location: ast::Location::new(2, 18, ast::FileName("unknown".into())), + custom: (), + node: ast::ArgData { + arg: "self".into(), + annotation: None, + type_comment: None, + }, + }, + ], + vararg: None, + kwonlyargs: vec![], + kw_defaults: vec![], + kwarg: None, + defaults: vec![], + }), + body: vec![ + ast::Located { + location: Location::new(3, 9, ast::FileName("unknown".into())), + custom: (), + node: ast::StmtKind::Pass { + config_comment: vec![], + }, + }, + ], + decorator_list: vec![], + returns: None, + type_comment: None, + config_comment: vec![], + }, + }), + ]); + "build_constructor_lookup_table" +)] +fn test_build_constructor_lookup(source: &str, result: HashMap>) { + let ast = parse_program(source, Default::default()); + let mut composer: TopLevelComposer = Default::default(); + if let Ok(stmts) = ast { + composer.build_constructor_lookup(stmts.iter()) + } + + assert_eq!(result.get(&"A".into()).unwrap(), composer.constructor_lookup.get(&"A".into()).unwrap()); + assert_eq!(result.get(&"B".into()).unwrap(), composer.constructor_lookup.get(&"B".into()).unwrap()); + assert_eq!(result.get(&"C".into()).unwrap(), composer.constructor_lookup.get(&"C".into()).unwrap()); +} + + #[test_case( vec![ indoc! {" diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index cc50fb63..703310a1 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -183,6 +183,7 @@ fn main() { Arc::new(Resolver(internal_resolver.clone())) as Arc; let parser_result = parser::parse_program(&program, file_name.into()).unwrap(); + composer.build_constructor_lookup(parser_result.iter()); for stmt in parser_result.into_iter() { match &stmt.node {