diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 1ea9611e..9e5ad3d4 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -23,7 +23,7 @@ class virtual(Generic[T]): compiler = nac3artiq.NAC3(core_arguments["target"]) allow_module_registration = True registered_modules = set() - +nac3annotated_class_ids = set() class KernelInvariant(Generic[T]): pass @@ -69,6 +69,7 @@ def nac3(cls): All classes containing kernels or portable methods must use this decorator. """ register_module_of(cls) + nac3annotated_class_ids.add(id(cls)) return cls @@ -111,7 +112,7 @@ class Core: def run(self, method, *args, **kwargs): global allow_module_registration if allow_module_registration: - compiler.analyze_modules(registered_modules) + compiler.analyze_modules(registered_modules, nac3annotated_class_ids) allow_module_registration = False if hasattr(method, "__self__"): diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 94a290a6..9e7cae5e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -11,7 +11,7 @@ use inkwell::{ use pyo3::prelude::*; use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use nac3parser::{ - ast::{self, StrRef}, + ast::{self, StrRef, Constant::Str}, parser::{self, parse_program}, }; @@ -76,7 +76,7 @@ struct Nac3 { } impl Nac3 { - fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { + fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> { let mut name_to_pyid: HashMap = HashMap::new(); let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let obj: &PyAny = obj.extract(py)?; @@ -111,7 +111,7 @@ impl Nac3 { global_value_ids: self.global_value_ids.clone(), class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), - module: obj, + module: obj.clone(), }) as Arc; let mut name_to_def = HashMap::new(); let mut name_to_type = HashMap::new(); @@ -121,6 +121,7 @@ impl Nac3 { ast::StmtKind::ClassDef { ref decorator_list, ref mut body, + ref mut bases, .. } => { let kernels = decorator_list.iter().any(|decorator| { @@ -146,6 +147,33 @@ impl Nac3 { true } }); + bases.retain(|b| { + Python::with_gil(|py| -> PyResult { + let obj: &PyAny = obj.extract(py)?; + let annot_check = |id: &str| -> bool { + let id = py.eval( + &format!("id({})", id), + Some(obj.getattr("__dict__").unwrap().extract().unwrap()), + None + ).unwrap(); + nac3_annotated_cls.contains(id).unwrap() + }; + match &b.node { + ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())), + ast::ExprKind::Constant { value: Str(id), .. } => + Ok(annot_check(id.split('[').next().unwrap())), + ast::ExprKind::Subscript { value, .. } => { + match &value.node { + ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()), + ast::ExprKind::Constant { value: Str(id), .. } => + Ok(annot_check(id.split('[').next().unwrap())), + _ => unreachable!("unsupported base declaration") + } + } + _ => unreachable!("unsupported base declaration") + } + }).unwrap() + }); kernels } ast::StmtKind::FunctionDef { @@ -336,9 +364,9 @@ impl Nac3 { }) } - fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { + fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> { for obj in modules.iter() { - self.register_module_impl(obj.into())?; + self.register_module_impl(obj.into(), nac3_annotated_cls)?; } Ok(()) }