diff --git a/nac3artiq/language.py b/nac3artiq/language.py index 77d79b2..101bbcf 100644 --- a/nac3artiq/language.py +++ b/nac3artiq/language.py @@ -1,4 +1,4 @@ -from inspect import isclass +from inspect import isclass, getmodule from functools import wraps import nac3artiq @@ -8,28 +8,28 @@ __all__ = ["extern", "kernel"] nac3 = nac3artiq.NAC3() -allow_object_registration = True +allow_module_registration = True def extern(function): - assert allow_object_registration - nac3.register_object(function) + assert allow_module_registration + nac3.register_module(getmodule(function)) return function -def kernel(function_or_class): - global allow_object_registration +def kernel(class_or_function): + global allow_module_registration - if isclass(function_or_class): - assert allow_object_registration - nac3.register_object(function_or_class) - return function_or_class + assert allow_module_registration + nac3.register_module(getmodule(class_or_function)) + if isclass(class_or_function): + return class_or_function else: - @wraps(function_or_class) + @wraps(class_or_function) def run_on_core(self, *args, **kwargs): - global allow_object_registration - if allow_object_registration: + global allow_module_registration + if allow_module_registration: nac3.analyze() - allow_object_registration = False - nac3.compile_method(self.__class__.__name__, function_or_class.__name__) + allow_module_registration = False + nac3.compile_method(self.__class__.__name__, class_or_function.__name__) return run_on_core diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 4dec48a..a391dc6 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -1,3 +1,4 @@ +use std::fs; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::path::Path; @@ -5,7 +6,7 @@ use std::process::Command; use pyo3::prelude::*; use pyo3::exceptions; -use rustpython_parser::parser; +use rustpython_parser::{ast, parser}; use inkwell::{ passes::{PassManager, PassManagerBuilder}, targets::*, @@ -31,7 +32,8 @@ struct Nac3 { internal_resolver: Arc, resolver: Arc>, composer: TopLevelComposer, - top_level: Option> + top_level: Option>, + registered_module_ids: HashSet } #[pymethods] @@ -63,33 +65,57 @@ impl Nac3 { internal_resolver, resolver, composer, - top_level: None + top_level: None, + registered_module_ids: HashSet::new() } } - fn register_object(&mut self, obj: PyObject) -> PyResult<()> { - Python::with_gil(|py| -> PyResult<()> { + fn register_module(&mut self, obj: PyObject) -> PyResult<()> { + let module_info = Python::with_gil(|py| -> PyResult> { let obj: &PyAny = obj.extract(py)?; + let builtins = PyModule::import(py, "builtins")?; + let id = builtins.getattr("id")?.call1((obj, ))?.extract()?; + if self.registered_module_ids.insert(id) { + Ok(Some((obj.getattr("__name__")?.extract()?, obj.getattr("__file__")?.extract()?))) + } else { + Ok(None) + } + })?; - let source = PyModule::import(py, "inspect")?.getattr("getsource")?.call1((obj, ))?.extract()?; - let parser_result = parser::parse_program(source).map_err(|e| + if let Some((module_name, source_file)) = module_info { + let source = fs::read_to_string(source_file).map_err(|e| + exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)))?; + let parser_result = parser::parse_program(&source).map_err(|e| exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?; for stmt in parser_result.into_iter() { - let (name, def_id, ty) = self.composer.register_top_level( - stmt, - Some(self.resolver.clone()), - "__main__".into(), - ).unwrap(); + let include = match &stmt.node { + ast::StmtKind::ClassDef { decorator_list, .. } => { + decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node + { id.to_string() == "kernel" || id.to_string() == "portable" } else { false }) + }, + ast::StmtKind::FunctionDef { decorator_list, .. } => { + decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node + { id.to_string() == "extern" || id.to_string() == "portable" } else { false }) + }, + _ => false + }; - self.internal_resolver.add_id_def(name.clone(), def_id); - if let Some(ty) = ty { - self.internal_resolver.add_id_type(name, ty); + if include { + let (name, def_id, ty) = self.composer.register_top_level( + stmt, + Some(self.resolver.clone()), + module_name.clone(), + ).unwrap(); + + self.internal_resolver.add_id_def(name.clone(), def_id); + if let Some(ty) = ty { + self.internal_resolver.add_id_type(name, ty); + } } } - - Ok(()) - }) + } + Ok(()) } fn analyze(&mut self) -> PyResult<()> {