From 75bd9b9a1510cb4b95ae7bad7a6a9be2ae8321b1 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 19 Dec 2020 15:29:39 +0800 Subject: [PATCH] nac3embedded: ast storage --- nac3embedded/language.py | 4 +- nac3embedded/src/lib.rs | 117 +++++++++++++++++++++++++++++++++------ 2 files changed, 103 insertions(+), 18 deletions(-) diff --git a/nac3embedded/language.py b/nac3embedded/language.py index 57e10d938..1fdee5bf5 100644 --- a/nac3embedded/language.py +++ b/nac3embedded/language.py @@ -9,7 +9,9 @@ __all__ = ["kernel", "portable"] def kernel(function): @wraps(function) def run_on_core(self, *args, **kwargs): - nac3embedded.add_host_object(self) + nac3 = nac3embedded.NAC3() + nac3.register_host_object(self) + nac3.compile_method(self, function.__name__) return run_on_core diff --git a/nac3embedded/src/lib.rs b/nac3embedded/src/lib.rs index bffe23742..91165859f 100644 --- a/nac3embedded/src/lib.rs +++ b/nac3embedded/src/lib.rs @@ -1,24 +1,107 @@ -use pyo3::prelude::*; -use pyo3::wrap_pyfunction; -use pyo3::exceptions; -use rustpython_parser::parser; +use std::collections::HashMap; +use std::collections::hash_map::Entry; -#[pyfunction] -fn add_host_object(obj: PyObject) -> PyResult<()> { - Python::with_gil(|py| -> PyResult<()> { - let obj: &PyAny = obj.extract(py)?; - let inspect = PyModule::import(py, "inspect")?; - let source = inspect.call1("getsource", (obj.get_type(), ))?; - let ast = parser::parse_program(source.extract()?).map_err(|e| - exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?; - println!("{:?}", ast); - Ok(()) - })?; - Ok(()) +use pyo3::prelude::*; +use pyo3::exceptions; +use rustpython_parser::{ast, parser}; + +fn runs_on_core(decorator_list: &[ast::Expression]) -> bool { + for decorator in decorator_list.iter() { + if let ast::ExpressionType::Identifier { name } = &decorator.node { + if name == "kernel" || name == "portable" { + return true + } + } + } + false +} + +#[pyclass(name=NAC3)] +struct Nac3 { + type_definitions: HashMap, + host_objects: HashMap, +} + +#[pymethods] +impl Nac3 { + #[new] + fn new() -> Self { + Nac3 { + type_definitions: HashMap::new(), + host_objects: HashMap::new(), + } + } + + fn register_host_object(&mut self, obj: PyObject) -> PyResult<()> { + Python::with_gil(|py| -> PyResult<()> { + let obj: &PyAny = obj.extract(py)?; + let obj_type = obj.get_type(); + + let builtins = PyModule::import(py, "builtins")?; + let type_id = builtins.call1("id", (obj_type, ))?.extract()?; + + let entry = self.type_definitions.entry(type_id); + if let Entry::Vacant(entry) = entry { + let source = PyModule::import(py, "inspect")?.call1("getsource", (obj_type, ))?; + let ast = parser::parse_program(source.extract()?).map_err(|e| + exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?; + entry.insert(ast); + // TODO: examine AST and recursively register dependencies + }; + + let obj_id = builtins.call1("id", (obj, ))?.extract()?; + match self.host_objects.entry(obj_id) { + Entry::Vacant(entry) => entry.insert(type_id), + Entry::Occupied(_) => return Err( + exceptions::PyValueError::new_err("host object registered twice")), + }; + // TODO: collect other information about host object, e.g. value of fields + + Ok(()) + }) + } + + fn compile_method(&self, obj: PyObject, name: String) -> PyResult<()> { + Python::with_gil(|py| -> PyResult<()> { + let obj: &PyAny = obj.extract(py)?; + let builtins = PyModule::import(py, "builtins")?; + let obj_id = builtins.call1("id", (obj, ))?.extract()?; + + let type_id = self.host_objects.get(&obj_id).ok_or_else(|| + exceptions::PyKeyError::new_err("type of host object not found"))?; + let ast = self.type_definitions.get(&type_id).ok_or_else(|| + exceptions::PyKeyError::new_err("type definition not found"))?; + + if let ast::StatementType::ClassDef { + name: _, + body, + bases: _, + keywords: _, + decorator_list: _ } = &ast.statements[0].node { + for statement in body.iter() { + if let ast::StatementType::FunctionDef { + is_async: _, + name: funcdef_name, + args: _, + body: _, + decorator_list, + returns: _ } = &statement.node { + if runs_on_core(decorator_list) && funcdef_name == &name { + println!("found: {:?}", &statement.node); + } + } + } + } else { + return Err(exceptions::PyValueError::new_err("expected ClassDef for type definition")); + } + + Ok(()) + }) + } } #[pymodule] fn nac3embedded(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(add_host_object, m)?)?; + m.add_class::()?; Ok(()) }