forked from M-Labs/nac3
1
0
Fork 0

Merge pull request 'fix errors of non-primitive host object when running multiple kernels' (#171) from multiple_kernel_err into master

Reviewed-on: M-Labs/nac3#171
This commit is contained in:
pca006132 2022-01-27 14:46:22 +08:00
commit 304181fd8c
1 changed files with 12 additions and 23 deletions

View File

@ -78,9 +78,7 @@ struct Nac3 {
builtins_ty: HashMap<StrRef, Type>,
builtins_def: HashMap<StrRef, DefinitionId>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
primitive_ids: PrimitivePythonId,
global_value_ids: Arc<RwLock<HashSet<u64>>>,
working_directory: TempDir,
top_levels: Vec<TopLevelComponent>,
}
@ -406,8 +404,6 @@ impl Nac3 {
primitive_ids,
top_levels: Default::default(),
pyid_to_def: Default::default(),
pyid_to_type: Default::default(),
global_value_ids: Default::default(),
working_directory,
})
}
@ -451,8 +447,6 @@ impl Nac3 {
kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant"
});
let mut id_to_def = HashMap::new();
let mut id_to_type = HashMap::new();
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
@ -466,6 +460,8 @@ impl Nac3 {
};
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new()));
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
@ -486,9 +482,9 @@ impl Nac3 {
id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(),
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: module.clone(),
@ -508,19 +504,12 @@ impl Nac3 {
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?;
let id = *name_to_pyid.get(&name).unwrap();
id_to_def.insert(id, def_id);
if let Some(ty) = ty {
id_to_type.insert(id, ty);
}
}
self.pyid_to_def.write().insert(id, def_id);
{
let mut map = self.pyid_to_def.write();
for (id, def) in id_to_def.into_iter() {
map.insert(id, def);
let mut pyid_to_ty = pyid_to_type.write();
if let Some(ty) = ty {
pyid_to_ty.insert(id, ty);
}
let mut map = self.pyid_to_type.write();
for (id, ty) in id_to_type.into_iter() {
map.insert(id, ty);
}
}
@ -553,9 +542,9 @@ impl Nac3 {
id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(),
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Default::default(),
id_to_pyval: Default::default(),
id_to_primitive: Default::default(),