From 4baa7be92e3c68736c4af86d385ca88d71f8f6e7 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Tue, 19 Nov 2024 16:55:24 +0800 Subject: [PATCH] nac3artiq: support kernels sent by content --- nac3artiq/demo/min_artiq.py | 2 +- nac3artiq/src/lib.rs | 57 ++++++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 3840a57a..62d32cc3 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -206,7 +206,7 @@ class Core: embedding = EmbeddingMap() if allow_registration: - compiler.analyze(registered_functions, registered_classes) + compiler.analyze(registered_functions, registered_classes, set()) allow_registration = False if hasattr(method, "__self__"): diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 6e80fd03..300bed8e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -30,7 +30,7 @@ use parking_lot::{Mutex, RwLock}; use pyo3::{ create_exception, exceptions, prelude::*, - types::{PyBytes, PyDict, PySet}, + types::{PyBytes, PyDict, PyNone, PySet}, }; use tempfile::{self, TempDir}; @@ -148,14 +148,32 @@ impl Nac3 { module: &PyObject, registered_class_ids: &HashSet, ) -> PyResult<()> { - let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { - let module: &PyAny = module.extract(py)?; - Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) - })?; + let (module_name, source_file, source) = + Python::with_gil(|py| -> PyResult<(String, String, String)> { + let module: &PyAny = module.extract(py)?; + let source_file = module.getattr("__file__"); + let (source_file, source) = if let Ok(source_file) = source_file { + let source_file = source_file.extract()?; + ( + source_file, + fs::read_to_string(&source_file).map_err(|e| { + exceptions::PyIOError::new_err(format!( + "failed to read input file: {e}" + )) + })?, + ) + } else { + // kernels submitted by content have no file + // but still can provide source by StringLoader + let get_src_fn = module + .getattr("__loader__")? + .extract::()? + .getattr(py, "get_source")?; + ("", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?) + }; + Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source)) + })?; - let source = fs::read_to_string(&source_file).map_err(|e| { - exceptions::PyIOError::new_err(format!("failed to read input file: {e}")) - })?; let parser_result = parse_program(&source, source_file.into()) .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; @@ -1085,7 +1103,12 @@ impl Nac3 { }) } - fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { + fn analyze( + &mut self, + functions: &PySet, + classes: &PySet, + content_modules: &PySet, + ) -> PyResult<()> { let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { let mut modules: HashMap = HashMap::new(); @@ -1095,14 +1118,22 @@ impl Nac3 { let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; for function in functions { - let module = getmodule_fn.call1((function,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); + let module: PyObject = getmodule_fn.call1((function,))?.extract()?; + if !module.is_none(py) { + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } } for class in classes { - let module = getmodule_fn.call1((class,))?.extract()?; - modules.insert(id_fn.call1((&module,))?.extract()?, module); + let module: PyObject = getmodule_fn.call1((class,))?.extract()?; + if !module.is_none(py) { + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } class_ids.insert(id_fn.call1((class,))?.extract()?); } + for module in content_modules { + let module: PyObject = module.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module.into()); + } Ok((modules, class_ids)) })?;