1
0
forked from M-Labs/nac3

nac3artiq: support kernels sent by content

This commit is contained in:
mwojcik 2024-11-19 16:55:24 +08:00
parent 979209a526
commit 4baa7be92e
2 changed files with 45 additions and 14 deletions

View File

@ -206,7 +206,7 @@ class Core:
embedding = EmbeddingMap() embedding = EmbeddingMap()
if allow_registration: if allow_registration:
compiler.analyze(registered_functions, registered_classes) compiler.analyze(registered_functions, registered_classes, set())
allow_registration = False allow_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

View File

@ -30,7 +30,7 @@ use parking_lot::{Mutex, RwLock};
use pyo3::{ use pyo3::{
create_exception, exceptions, create_exception, exceptions,
prelude::*, prelude::*,
types::{PyBytes, PyDict, PySet}, types::{PyBytes, PyDict, PyNone, PySet},
}; };
use tempfile::{self, TempDir}; use tempfile::{self, TempDir};
@ -148,14 +148,32 @@ impl Nac3 {
module: &PyObject, module: &PyObject,
registered_class_ids: &HashSet<u64>, registered_class_ids: &HashSet<u64>,
) -> PyResult<()> { ) -> PyResult<()> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file, source) =
let module: &PyAny = module.extract(py)?; Python::with_gil(|py| -> PyResult<(String, String, String)> {
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) let module: &PyAny = module.extract(py)?;
})?; let source_file = module.getattr("__file__");
let (source_file, source) = if let Ok(source_file) = source_file {
let source_file = source_file.extract()?;
(
source_file,
fs::read_to_string(&source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!(
"failed to read input file: {e}"
))
})?,
)
} else {
// kernels submitted by content have no file
// but still can provide source by StringLoader
let get_src_fn = module
.getattr("__loader__")?
.extract::<PyObject>()?
.getattr(py, "get_source")?;
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
};
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
})?;
let source = fs::read_to_string(&source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
})?;
let parser_result = parse_program(&source, source_file.into()) let parser_result = parse_program(&source, source_file.into())
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
@ -1085,7 +1103,12 @@ impl Nac3 {
}) })
} }
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { fn analyze(
&mut self,
functions: &PySet,
classes: &PySet,
content_modules: &PySet,
) -> PyResult<()> {
let (modules, class_ids) = let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> { Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new(); let mut modules: HashMap<u64, PyObject> = HashMap::new();
@ -1095,14 +1118,22 @@ impl Nac3 {
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions { for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?; let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
} }
for class in classes { for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?; let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
class_ids.insert(id_fn.call1((class,))?.extract()?); class_ids.insert(id_fn.call1((class,))?.extract()?);
} }
for module in content_modules {
let module: PyObject = module.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module.into());
}
Ok((modules, class_ids)) Ok((modules, class_ids))
})?; })?;