From 0439bf6aef5a48ddc0b24e871bc7bd6028a2f226 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sat, 15 Jan 2022 04:43:39 +0800 Subject: [PATCH] nac3artiq: fix errors of non-primitive object when running multiple kernels --- nac3artiq/src/lib.rs | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index a180a283..1d7c476b 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -78,9 +78,7 @@ struct Nac3 { builtins_ty: HashMap, builtins_def: HashMap, pyid_to_def: Arc>>, - pyid_to_type: Arc>>, primitive_ids: PrimitivePythonId, - global_value_ids: Arc>>, working_directory: TempDir, top_levels: Vec, } @@ -406,8 +404,6 @@ impl Nac3 { primitive_ids, top_levels: Default::default(), pyid_to_def: Default::default(), - pyid_to_type: Default::default(), - global_value_ids: Default::default(), working_directory, }) } @@ -451,8 +447,6 @@ impl Nac3 { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }); - let mut id_to_def = HashMap::new(); - let mut id_to_type = HashMap::new(); let builtins = PyModule::import(py, "builtins")?; let typings = PyModule::import(py, "typing")?; @@ -466,6 +460,8 @@ impl Nac3 { }; let mut module_to_resolver_cache: HashMap = HashMap::new(); + let pyid_to_type = Arc::new(RwLock::new(HashMap::::new())); + let global_value_ids = Arc::new(RwLock::new(HashSet::::new())); for (stmt, path, module) in self.top_levels.iter() { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; @@ -486,9 +482,9 @@ impl Nac3 { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), - pyid_to_type: self.pyid_to_type.clone(), + pyid_to_type: pyid_to_type.clone(), primitive_ids: self.primitive_ids.clone(), - global_value_ids: self.global_value_ids.clone(), + global_value_ids: global_value_ids.clone(), class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), module: module.clone(), @@ -508,19 +504,12 @@ impl Nac3 { .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?; let id = *name_to_pyid.get(&name).unwrap(); - id_to_def.insert(id, def_id); - if let Some(ty) = ty { - id_to_type.insert(id, ty); - } - } - { - let mut map = self.pyid_to_def.write(); - for (id, def) in id_to_def.into_iter() { - map.insert(id, def); - } - let mut map = self.pyid_to_type.write(); - for (id, ty) in id_to_type.into_iter() { - map.insert(id, ty); + self.pyid_to_def.write().insert(id, def_id); + { + let mut pyid_to_ty = pyid_to_type.write(); + if let Some(ty) = ty { + pyid_to_ty.insert(id, ty); + } } } @@ -553,9 +542,9 @@ impl Nac3 { id_to_type: self.builtins_ty.clone().into(), id_to_def: self.builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), - pyid_to_type: self.pyid_to_type.clone(), + pyid_to_type: pyid_to_type.clone(), primitive_ids: self.primitive_ids.clone(), - global_value_ids: self.global_value_ids.clone(), + global_value_ids: global_value_ids.clone(), class_names: Default::default(), id_to_pyval: Default::default(), id_to_primitive: Default::default(),