From 1566f67917611eb84dbdc94d669b8dca7f566f86 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Thu, 31 Oct 2024 17:36:57 +0800 Subject: [PATCH] pass destination to composer --- nac3artiq/src/lib.rs | 61 +++++++++++++++++++++++++++---- nac3core/src/toplevel/composer.rs | 20 +++++++--- nac3core/src/toplevel/test.rs | 6 +-- nac3standalone/src/main.rs | 2 +- 4 files changed, 72 insertions(+), 17 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index a52640d3..8b8f3a75 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -383,6 +383,7 @@ impl Nac3 { let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); + let store_subkernel = embedding_map.getattr("store_subkernel").unwrap().to_object(py); let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py); let global_value_ids: Arc>> = Arc::new(RwLock::new(HashMap::new())); let helper = PythonHelper { @@ -502,21 +503,26 @@ impl Nac3 { .iter() .any(|decorator| decorator_id_string(decorator) == Some("subkernel".to_string())) { - if let Constant::Int(destination) = decorator_get_destination(decorator) { - if destination < 0 || destination > 255 { + if let Constant::Int(sk_dest) = decorator_get_destination(decorator) { + if sk_dest < 0 || sk_dest > 255 { return Err(CompileError::new_err(format!( "compilation failed\n----------\nSubkernel destination must be between 0 and 255 (at {})", stmt.location ))); } - subkernel_ids.push((None, def_id, destination)); + if sk_dest != destination { + // subkernels with the same destination as currently compiled kernel + // are treated as normal kernels + store_subkernel.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); + subkernel_ids.push((None, def_id, destination)); + } } else { return Err(CompileError::new_err(format!( "compilation failed\n----------\nDestination must be provided for subkernels (at {})", stmt.location ))); } - // store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); + } } @@ -541,20 +547,22 @@ impl Nac3 { } rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async)); } else if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "subkernel".into())) { - if let Constant::Int(destination) = decorator_get_destination(decorator) { + if let Constant::Int(sk_dest) = decorator_get_destination(decorator) { if name == &"__init__".into() { return Err(CompileError::new_err(format!( "compilation failed\n----------\nThe constructor of class {} should not be decorated with subkernel decorator (at {})", class_name, stmt.location ))); } - if destination < 0 || destination > 255 { + if sk_dest < 0 || sk_dest > 255 { return Err(CompileError::new_err(format!( "compilation failed\n----------\nSubkernel destination must be between 0 and 255 (at {})", stmt.location ))); } - subkernel_ids.push((Some((class_obj.clone(), *name)), def_id, destination)); + if sk_dest != destination { + subkernel_ids.push((Some((class_obj.clone(), *name)), def_id, sk_dest)); + } } else { return Err(CompileError::new_err(format!( "compilation failed\n----------\nDestination must be provided for subkernels (at {})", @@ -638,7 +646,7 @@ impl Nac3 { ); let signature = store.add_cty(signature); - if let Err(e) = composer.start_analysis(true) { + if let Err(e) = composer.start_analysis(true, Some(destination)) { // report error of __modinit__ separately return if e.iter().any(|err| err.contains("")) { let msg = Self::report_modinit( @@ -701,6 +709,43 @@ impl Nac3 { } } } + for (class_data, id, sk_dest) in &subkernel_ids { + let mut def = defs[id.0].write(); + match &mut *def { + TopLevelDef::Function { codegen_callback, .. } => { + *codegen_callback = Some(subkernel_codegen_callback()); + } + TopLevelDef::Class { methods, .. } => { + let (class_def, method_name) = class_data.as_ref().unwrap(); + for (name, _, id) in &*methods { + if name != method_name { + continue; + } + if let TopLevelDef::Function { codegen_callback, .. } = + &mut *defs[id.0].write() + { + *codegen_callback = Some(subkernel_codegen_callback()); + store_subkernel + .call1( + py, + ( + id.0.into_py(py), + class_def + .getattr(py, name.to_string().as_str()) + .unwrap(), + ), + ) + .unwrap(); + } + } + } + TopLevelDef::Variable { .. } => { + return Err(CompileError::new_err(String::from( + "Unsupported @subkernel annotation on global variable", + ))) + } + } + } } let instance = { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 5e456888..eaf4457a 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -463,13 +463,13 @@ impl TopLevelComposer { Ok((name, DefinitionId(self.definition_ast_list.len() - 1), Some(ty_to_be_unified))) } - pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet> { + pub fn start_analysis(&mut self, inference: bool, destination: Option) -> Result<(), HashSet> { self.analyze_top_level_class_type_var()?; self.analyze_top_level_class_bases()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; if inference { - self.analyze_function_instance()?; + self.analyze_function_instance(destination)?; } self.analyze_top_level_variables()?; Ok(()) @@ -1736,7 +1736,8 @@ impl TopLevelComposer { /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] - fn analyze_function_instance(&mut self) -> Result<(), HashSet> { + /// destination is an ARTIQ-only argument necessary for proper support of subkernels + fn analyze_function_instance(&mut self, destination: Option) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body // also do class field instantiation check let init_str_id = "__init__".into(); @@ -2131,14 +2132,23 @@ impl TopLevelComposer { else { unreachable!() }; - instance_to_symbol.insert(String::new(), simple_name.to_string()); + // do this only if destination differs from current + if let Some(dest) == /* also extract destination from the decorator ... */ + { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + } continue; } if !decorator_list.is_empty() && matches!(&decorator_list[0].node, ast::ExprKind::Name{ id, .. } if id == &"subkernel".into()) { - instance_to_symbol.insert(String::new(), simple_name.to_string()); + // do this only if destination differs from current + if let Some(dest) == /* also extract destination from the decorator ... */ + { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + } + continue; } diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 41dcf0df..c211c62a 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -199,7 +199,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) { } } - composer.start_analysis(true).unwrap(); + composer.start_analysis(true, None).unwrap(); for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate() { @@ -563,7 +563,7 @@ fn test_analyze(source: &[&str], res: &[&str]) { } } - if let Err(msg) = composer.start_analysis(false) { + if let Err(msg) = composer.start_analysis(false, None) { if print { println!("{}", msg.iter().sorted().join("\n----------\n")); } else { @@ -748,7 +748,7 @@ fn test_inference(source: Vec<&str>, res: &[&str]) { } } - if let Err(msg) = composer.start_analysis(true) { + if let Err(msg) = composer.start_analysis(true, None) { if print { println!("{}", msg.iter().sorted().join("\n----------\n")); } else { diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 965d274e..befcbb8d 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -408,7 +408,7 @@ fn main() { let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); let signature = store.add_cty(signature); - if let Err(errors) = composer.start_analysis(true) { + if let Err(errors) = composer.start_analysis(true, None) { let error_count = errors.len(); eprintln!("{error_count} error(s) occurred during top level analysis.");