diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 3840a57a..62d32cc3 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -206,7 +206,7 @@ class Core: embedding = EmbeddingMap() if allow_registration: - compiler.analyze(registered_functions, registered_classes) + compiler.analyze(registered_functions, registered_classes, set()) allow_registration = False if hasattr(method, "__self__"): diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 6e80fd03..86ab2045 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -30,7 +30,7 @@ use parking_lot::{Mutex, RwLock}; use pyo3::{ create_exception, exceptions, prelude::*, - types::{PyBytes, PyDict, PySet}, + types::{PyBytes, PyDict, PySet, PyNone}, }; use tempfile::{self, TempDir}; @@ -148,14 +148,22 @@ impl Nac3 { module: &PyObject, registered_class_ids: &HashSet, ) -> PyResult<()> { - let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { + let (module_name, source_file, source) = Python::with_gil(|py| -> PyResult<(String, String, String)> { let module: &PyAny = module.extract(py)?; - Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) + let source_file = module.getattr("__file__"); + let (source_file, source) = if let Ok(source_file) = source_file { + let source_file = source_file.extract()?; + (source_file, fs::read_to_string(&source_file).map_err(|e| { + exceptions::PyIOError::new_err(format!("failed to read input file: {e}")) + })?) + } else { + // kernels submitted by content have no file + let get_src_fn = module.getattr("__loader__")?.extract::()?.getattr(py, "get_source")?; + ("", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?) + }; + Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source)) })?; - let source = fs::read_to_string(&source_file).map_err(|e| { - exceptions::PyIOError::new_err(format!("failed to read input file: {e}")) - })?; let parser_result = parse_program(&source, source_file.into()) .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; @@ -1085,7 +1093,7 @@ impl Nac3 { }) } - fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { + fn analyze(&mut self, functions: &PySet, classes: &PySet, content_modules: &PySet) -> PyResult<()> { let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { let mut modules: HashMap = HashMap::new(); @@ -1095,14 +1103,22 @@ impl Nac3 { let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; for function in functions { - let module = getmodule_fn.call1((function,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); + let module: PyObject = getmodule_fn.call1((function,))?.extract()?; + if !module.is_none(py) { + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } } for class in classes { - let module = getmodule_fn.call1((class,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); + let module: PyObject = getmodule_fn.call1((class,))?.extract()?; + if !module.is_none(py) { + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } class_ids.insert(id_fn.call1((class,))?.extract()?); } + for module in content_modules { + let module: PyObject = module.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module.into()); + } Ok((modules, class_ids)) })?;