From 65bc1e5fa446e7ff664f6567bf96e0ec54416ef6 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 4 Dec 2021 22:57:48 +0800 Subject: [PATCH] nac3artiq: handle name_to_pyid in compilation python variables can change between kernel invocations --- nac3artiq/src/lib.rs | 116 ++++++++++++++++--------------- nac3artiq/src/symbol_resolver.rs | 1 + 2 files changed, 60 insertions(+), 57 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index f27abde4..46ee93aa 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -15,7 +15,7 @@ use nac3parser::{ parser::{self, parse_program}, }; use pyo3::prelude::*; -use pyo3::{exceptions, types::PyBytes, types::PyList, types::PySet}; +use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; use parking_lot::{Mutex, RwLock}; @@ -63,12 +63,7 @@ pub struct PrimitivePythonId { virtual_id: u64, } -type TopLevelComponent = ( - Stmt, - Arc, - String, - Rc>, -); +type TopLevelComponent = (Stmt, String, PyObject); // TopLevelComposer is unsendable as it holds the unification table, which is // unsendable due to Rc. Arc would cause a performance hit. @@ -94,35 +89,13 @@ impl Nac3 { module: PyObject, registered_class_ids: &HashSet, ) -> PyResult<()> { - let mut name_to_pyid: HashMap = HashMap::new(); - let (module_name, source_file, helper) = - Python::with_gil(|py| -> PyResult<(String, String, PythonHelper)> { - let module: &PyAny = module.extract(py)?; - let builtins = PyModule::import(py, "builtins")?; - let id_fn = builtins.getattr("id")?; - let members: &PyList = PyModule::import(py, "inspect")? - .getattr("getmembers")? - .call1((module,))? - .cast_as()?; - for member in members.iter() { - let key: &str = member.get_item(0)?.extract()?; - let val = id_fn.call1((member.get_item(1)?,))?.extract()?; - name_to_pyid.insert(key.into(), val); - } - let typings = PyModule::import(py, "typing")?; - let helper = PythonHelper { - id_fn: builtins.getattr("id").unwrap().to_object(py), - len_fn: builtins.getattr("len").unwrap().to_object(py), - type_fn: builtins.getattr("type").unwrap().to_object(py), - origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), - args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), - }; - Ok(( - module.getattr("__name__")?.extract()?, - module.getattr("__file__")?.extract()?, - helper, - )) - })?; + let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { + let module: &PyAny = module.extract(py)?; + Ok(( + module.getattr("__name__")?.extract()?, + module.getattr("__file__")?.extract()?, + )) + })?; let source = fs::read_to_string(source_file).map_err(|e| { exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)) @@ -130,20 +103,6 @@ impl Nac3 { let parser_result = parser::parse_program(&source) .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?; - let resolver = Arc::new(Resolver(Arc::new(InnerResolver { - id_to_type: self.builtins_ty.clone().into(), - id_to_def: self.builtins_def.clone().into(), - pyid_to_def: self.pyid_to_def.clone(), - pyid_to_type: self.pyid_to_type.clone(), - primitive_ids: self.primitive_ids.clone(), - global_value_ids: self.global_value_ids.clone(), - class_names: Default::default(), - name_to_pyid: name_to_pyid.clone(), - module: module.clone(), - helper, - }))) as Arc; - let name_to_pyid = Rc::new(name_to_pyid); - for mut stmt in parser_result.into_iter() { let include = match stmt.node { ast::StmtKind::ClassDef { @@ -209,12 +168,8 @@ impl Nac3 { }; if include { - self.top_levels.push(( - stmt, - resolver.clone(), - module_name.clone(), - name_to_pyid.clone(), - )); + self.top_levels + .push((stmt, module_name.clone(), module.clone())); } } Ok(()) @@ -423,7 +378,54 @@ impl Nac3 { let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone()); let mut id_to_def = HashMap::new(); let mut id_to_type = HashMap::new(); - for (stmt, resolver, path, name_to_pyid) in self.top_levels.iter() { + + let builtins = PyModule::import(py, "builtins")?; + let typings = PyModule::import(py, "typing")?; + let id_fn = builtins.getattr("id")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap().to_object(py), + len_fn: builtins.getattr("len").unwrap().to_object(py), + type_fn: builtins.getattr("type").unwrap().to_object(py), + origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), + args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), + }; + let mut module_to_resolver_cache: HashMap = HashMap::new(); + + for (stmt, path, module) in self.top_levels.iter() { + let py_module: &PyAny = module.extract(py)?; + let module_id: u64 = id_fn.call1((py_module,))?.extract()?; + let helper = helper.clone(); + let (name_to_pyid, resolver) = module_to_resolver_cache + .get(&module_id) + .cloned() + .unwrap_or_else(|| { + let mut name_to_pyid: HashMap = HashMap::new(); + let members: &PyDict = + py_module.getattr("__dict__").unwrap().cast_as().unwrap(); + for (key, val) in members.iter() { + let key: &str = key.extract().unwrap(); + let val = id_fn.call1((val,)).unwrap().extract().unwrap(); + name_to_pyid.insert(key.into(), val); + } + let resolver = Arc::new(Resolver(Arc::new(InnerResolver { + id_to_type: self.builtins_ty.clone().into(), + id_to_def: self.builtins_def.clone().into(), + pyid_to_def: self.pyid_to_def.clone(), + pyid_to_type: self.pyid_to_type.clone(), + primitive_ids: self.primitive_ids.clone(), + global_value_ids: self.global_value_ids.clone(), + class_names: Default::default(), + name_to_pyid: name_to_pyid.clone(), + module: module.clone(), + helper, + }))) + as Arc; + let name_to_pyid = Rc::new(name_to_pyid); + module_to_resolver_cache + .insert(module_id, (name_to_pyid.clone(), resolver.clone())); + (name_to_pyid, resolver) + }); + let (name, def_id, ty) = composer .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .unwrap(); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 92baa592..0eeedf25 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -39,6 +39,7 @@ pub struct InnerResolver { pub struct Resolver(pub Arc); +#[derive(Clone)] pub struct PythonHelper { pub type_fn: PyObject, pub len_fn: PyObject,