forked from M-Labs/nac3
1
0
Fork 0

pass destination to composer

This commit is contained in:
mwojcik 2024-10-31 17:36:57 +08:00
parent 2ef8b300b2
commit b0cb74423d
4 changed files with 72 additions and 17 deletions

View File

@ -383,6 +383,7 @@ impl Nac3 {
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py);
let store_subkernel = embedding_map.getattr("store_subkernel").unwrap().to_object(py);
let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py);
let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new()));
let helper = PythonHelper {
@ -502,21 +503,26 @@ impl Nac3 {
.iter()
.any(|decorator| decorator_id_string(decorator) == Some("subkernel".to_string()))
{
if let Constant::Int(destination) = decorator_get_destination(decorator) {
if destination < 0 || destination > 255 {
if let Constant::Int(sk_dest) = decorator_get_destination(decorator) {
if sk_dest < 0 || sk_dest > 255 {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nSubkernel destination must be between 0 and 255 (at {})",
stmt.location
)));
}
subkernel_ids.push((None, def_id, destination));
if sk_dest != destination {
// subkernels with the same destination as currently compiled kernel
// are treated as normal kernels
store_subkernel.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
subkernel_ids.push((None, def_id, destination));
}
} else {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nDestination must be provided for subkernels (at {})",
stmt.location
)));
}
// store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
}
}
@ -541,20 +547,22 @@ impl Nac3 {
}
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
} else if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "subkernel".into())) {
if let Constant::Int(destination) = decorator_get_destination(decorator) {
if let Constant::Int(sk_dest) = decorator_get_destination(decorator) {
if name == &"__init__".into() {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with subkernel decorator (at {})",
class_name, stmt.location
)));
}
if destination < 0 || destination > 255 {
if sk_dest < 0 || sk_dest > 255 {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nSubkernel destination must be between 0 and 255 (at {})",
stmt.location
)));
}
subkernel_ids.push((Some((class_obj.clone(), *name)), def_id, destination));
if sk_dest != destination {
subkernel_ids.push((Some((class_obj.clone(), *name)), def_id, sk_dest));
}
} else {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nDestination must be provided for subkernels (at {})",
@ -638,7 +646,7 @@ impl Nac3 {
);
let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) {
if let Err(e) = composer.start_analysis(true, Some(destination)) {
// report error of __modinit__ separately
return if e.iter().any(|err| err.contains("<nac3_synthesized_modinit>")) {
let msg = Self::report_modinit(
@ -701,6 +709,43 @@ impl Nac3 {
}
}
}
for (class_data, id, sk_dest) in &subkernel_ids {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(subkernel_codegen_callback());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
for (name, _, id) in &*methods {
if name != method_name {
continue;
}
if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write()
{
*codegen_callback = Some(subkernel_codegen_callback());
store_subkernel
.call1(
py,
(
id.0.into_py(py),
class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
),
)
.unwrap();
}
}
}
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @subkernel annotation on global variable",
)))
}
}
}
}
let instance = {

View File

@ -463,13 +463,13 @@ impl TopLevelComposer {
Ok((name, DefinitionId(self.definition_ast_list.len() - 1), Some(ty_to_be_unified)))
}
pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet<String>> {
pub fn start_analysis(&mut self, inference: bool, destination: Option<u8>) -> Result<(), HashSet<String>> {
self.analyze_top_level_class_type_var()?;
self.analyze_top_level_class_bases()?;
self.analyze_top_level_class_fields_methods()?;
self.analyze_top_level_function()?;
if inference {
self.analyze_function_instance()?;
self.analyze_function_instance(destination)?;
}
self.analyze_top_level_variables()?;
Ok(())
@ -1736,7 +1736,8 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
/// [`TopLevelDef::Function`]
fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> {
/// destination is an ARTIQ-only argument necessary for proper support of subkernels
fn analyze_function_instance(&mut self, destination: Option<u8>) -> Result<(), HashSet<String>> {
// first get the class constructor type correct for the following type check in function body
// also do class field instantiation check
let init_str_id = "__init__".into();
@ -2131,14 +2132,23 @@ impl TopLevelComposer {
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), simple_name.to_string());
// do this only if destination differs from current
if let Some(dest) == /* also extract destination from the decorator ... */
{
instance_to_symbol.insert(String::new(), simple_name.to_string());
}
continue;
}
if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"subkernel".into())
{
instance_to_symbol.insert(String::new(), simple_name.to_string());
// do this only if destination differs from current
if let Some(dest) == /* also extract destination from the decorator ... */
{
instance_to_symbol.insert(String::new(), simple_name.to_string());
}
continue;
}

View File

@ -199,7 +199,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
}
}
composer.start_analysis(true).unwrap();
composer.start_analysis(true, None).unwrap();
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
@ -563,7 +563,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
}
}
if let Err(msg) = composer.start_analysis(false) {
if let Err(msg) = composer.start_analysis(false, None) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
@ -748,7 +748,7 @@ fn test_inference(source: Vec<&str>, res: &[&str]) {
}
}
if let Err(msg) = composer.start_analysis(true) {
if let Err(msg) = composer.start_analysis(true, None) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {

View File

@ -408,7 +408,7 @@ fn main() {
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
let signature = store.add_cty(signature);
if let Err(errors) = composer.start_analysis(true) {
if let Err(errors) = composer.start_analysis(true, None) {
let error_count = errors.len();
eprintln!("{error_count} error(s) occurred during top level analysis.");