Merge pull request 'fix errors of non-primitive host object when running multiple kernels' (#171) from multiple_kernel_err into master

Reviewed-on: M-Labs/nac3#171
This commit is contained in:
pca006132 2022-01-27 14:46:22 +08:00
commit 304181fd8c
1 changed files with 12 additions and 23 deletions

View File

@ -78,9 +78,7 @@ struct Nac3 {
builtins_ty: HashMap<StrRef, Type>, builtins_ty: HashMap<StrRef, Type>,
builtins_def: HashMap<StrRef, DefinitionId>, builtins_def: HashMap<StrRef, DefinitionId>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
primitive_ids: PrimitivePythonId, primitive_ids: PrimitivePythonId,
global_value_ids: Arc<RwLock<HashSet<u64>>>,
working_directory: TempDir, working_directory: TempDir,
top_levels: Vec<TopLevelComponent>, top_levels: Vec<TopLevelComponent>,
} }
@ -406,8 +404,6 @@ impl Nac3 {
primitive_ids, primitive_ids,
top_levels: Default::default(), top_levels: Default::default(),
pyid_to_def: Default::default(), pyid_to_def: Default::default(),
pyid_to_type: Default::default(),
global_value_ids: Default::default(),
working_directory, working_directory,
}) })
} }
@ -451,8 +447,6 @@ impl Nac3 {
kernel_ann: Some("Kernel"), kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant" kernel_invariant_ann: "KernelInvariant"
}); });
let mut id_to_def = HashMap::new();
let mut id_to_type = HashMap::new();
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?; let typings = PyModule::import(py, "typing")?;
@ -466,6 +460,8 @@ impl Nac3 {
}; };
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new()));
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
for (stmt, path, module) in self.top_levels.iter() { for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
@ -486,9 +482,9 @@ impl Nac3 {
id_to_type: self.builtins_ty.clone().into(), id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(), id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(), pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: module.clone(), module: module.clone(),
@ -508,19 +504,12 @@ impl Nac3 {
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?; .map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?;
let id = *name_to_pyid.get(&name).unwrap(); let id = *name_to_pyid.get(&name).unwrap();
id_to_def.insert(id, def_id); self.pyid_to_def.write().insert(id, def_id);
if let Some(ty) = ty {
id_to_type.insert(id, ty);
}
}
{ {
let mut map = self.pyid_to_def.write(); let mut pyid_to_ty = pyid_to_type.write();
for (id, def) in id_to_def.into_iter() { if let Some(ty) = ty {
map.insert(id, def); pyid_to_ty.insert(id, ty);
} }
let mut map = self.pyid_to_type.write();
for (id, ty) in id_to_type.into_iter() {
map.insert(id, ty);
} }
} }
@ -553,9 +542,9 @@ impl Nac3 {
id_to_type: self.builtins_ty.clone().into(), id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(), id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(), pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
id_to_pyval: Default::default(), id_to_pyval: Default::default(),
id_to_primitive: Default::default(), id_to_primitive: Default::default(),