1
0
forked from M-Labs/nac3

nac3artiq: handle name_to_pyid in compilation

python variables can change between kernel invocations
This commit is contained in:
pca006132 2021-12-04 22:57:48 +08:00 committed by Gitea
parent 2938eacd16
commit 65bc1e5fa4
2 changed files with 60 additions and 57 deletions

View File

@ -15,7 +15,7 @@ use nac3parser::{
parser::{self, parse_program},
};
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyList, types::PySet};
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock};
@ -63,12 +63,7 @@ pub struct PrimitivePythonId {
virtual_id: u64,
}
type TopLevelComponent = (
Stmt,
Arc<dyn SymbolResolver + Send + Sync>,
String,
Rc<HashMap<StrRef, u64>>,
);
type TopLevelComponent = (Stmt, String, PyObject);
// TopLevelComposer is unsendable as it holds the unification table, which is
// unsendable due to Rc. Arc would cause a performance hit.
@ -94,33 +89,11 @@ impl Nac3 {
module: PyObject,
registered_class_ids: &HashSet<u64>,
) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file, helper) =
Python::with_gil(|py| -> PyResult<(String, String, PythonHelper)> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let module: &PyAny = module.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let id_fn = builtins.getattr("id")?;
let members: &PyList = PyModule::import(py, "inspect")?
.getattr("getmembers")?
.call1((module,))?
.cast_as()?;
for member in members.iter() {
let key: &str = member.get_item(0)?.extract()?;
let val = id_fn.call1((member.get_item(1)?,))?.extract()?;
name_to_pyid.insert(key.into(), val);
}
let typings = PyModule::import(py, "typing")?;
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py),
type_fn: builtins.getattr("type").unwrap().to_object(py),
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
};
Ok((
module.getattr("__name__")?.extract()?,
module.getattr("__file__")?.extract()?,
helper,
))
})?;
@ -130,20 +103,6 @@ impl Nac3 {
let parser_result = parser::parse_program(&source)
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?;
let resolver = Arc::new(Resolver(Arc::new(InnerResolver {
id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: module.clone(),
helper,
}))) as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid);
for mut stmt in parser_result.into_iter() {
let include = match stmt.node {
ast::StmtKind::ClassDef {
@ -209,12 +168,8 @@ impl Nac3 {
};
if include {
self.top_levels.push((
stmt,
resolver.clone(),
module_name.clone(),
name_to_pyid.clone(),
));
self.top_levels
.push((stmt, module_name.clone(), module.clone()));
}
}
Ok(())
@ -423,7 +378,54 @@ impl Nac3 {
let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone());
let mut id_to_def = HashMap::new();
let mut id_to_type = HashMap::new();
for (stmt, resolver, path, name_to_pyid) in self.top_levels.iter() {
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let id_fn = builtins.getattr("id")?;
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py),
type_fn: builtins.getattr("type").unwrap().to_object(py),
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
};
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
let helper = helper.clone();
let (name_to_pyid, resolver) = module_to_resolver_cache
.get(&module_id)
.cloned()
.unwrap_or_else(|| {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members: &PyDict =
py_module.getattr("__dict__").unwrap().cast_as().unwrap();
for (key, val) in members.iter() {
let key: &str = key.extract().unwrap();
let val = id_fn.call1((val,)).unwrap().extract().unwrap();
name_to_pyid.insert(key.into(), val);
}
let resolver = Arc::new(Resolver(Arc::new(InnerResolver {
id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: module.clone(),
helper,
})))
as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid);
module_to_resolver_cache
.insert(module_id, (name_to_pyid.clone(), resolver.clone()));
(name_to_pyid, resolver)
});
let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.unwrap();

View File

@ -39,6 +39,7 @@ pub struct InnerResolver {
pub struct Resolver(pub Arc<InnerResolver>);
#[derive(Clone)]
pub struct PythonHelper {
pub type_fn: PyObject,
pub len_fn: PyObject,