diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 529eea3d..88cb9815 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec, getmodule +from inspect import getfullargspec from functools import wraps from types import SimpleNamespace from numpy import int32, int64 @@ -11,34 +11,38 @@ __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3", "Core", "TTLOut", "parallel", "sequential"] -import device_db -core_arguments = device_db.device_db["core"]["arguments"] - -compiler = nac3artiq.NAC3(core_arguments["target"]) -allow_module_registration = True -registered_modules = set() - - T = TypeVar('T') class KernelInvariant(Generic[T]): pass -def register_module_of(obj): - assert allow_module_registration - # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. - registered_modules.add(getmodule(obj)) +import device_db +core_arguments = device_db.device_db["core"]["arguments"] + +compiler = nac3artiq.NAC3(core_arguments["target"]) +allow_registration = True +# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. +registered_functions = set() +registered_classes = set() + +def register_function(fun): + assert allow_registration + registered_functions.add(fun) + +def register_class(cls): + assert allow_registration + registered_classes.add(cls) def extern(function): """Decorates a function declaration defined by the core device runtime.""" - register_module_of(function) + register_function(function) return function def kernel(function_or_method): """Decorates a function or method to be executed on the core device.""" - register_module_of(function_or_method) + register_function(function_or_method) argspec = getfullargspec(function_or_method) if argspec.args and argspec.args[0] == "self": @wraps(function_or_method) @@ -54,7 +58,7 @@ def kernel(function_or_method): def portable(function): """Decorates a function or method to be executed on the same device (host/core device) as the caller.""" - register_module_of(function) + register_function(function) return function @@ -63,7 +67,7 @@ def nac3(cls): Decorates a class to be analyzed by NAC3. All classes containing kernels or portable methods must use this decorator. """ - register_module_of(cls) + register_class(cls) return cls @@ -104,10 +108,10 @@ class Core: self.ref_period = core_arguments["ref_period"] def run(self, method, *args, **kwargs): - global allow_module_registration - if allow_module_registration: - compiler.analyze_modules(registered_modules) - allow_module_registration = False + global allow_registration + if allow_registration: + compiler.analyze(registered_functions, registered_classes) + allow_registration = False if hasattr(method, "__self__"): obj = method.__self__ diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index bfb319a8..ee67cb83 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -72,15 +72,15 @@ struct Nac3 { } impl Nac3 { - fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { + fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet) -> PyResult<()> { let mut name_to_pyid: HashMap = HashMap::new(); let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { - let obj: &PyAny = obj.extract(py)?; + let module: &PyAny = module.extract(py)?; let builtins = PyModule::import(py, "builtins")?; let id_fn = builtins.getattr("id")?; let members: &PyList = PyModule::import(py, "inspect")? .getattr("getmembers")? - .call1((obj,))? + .call1((module,))? .cast_as()?; for member in members.iter() { let key: &str = member.get_item(0)?.extract()?; @@ -88,8 +88,8 @@ impl Nac3 { name_to_pyid.insert(key.into(), val); } Ok(( - obj.getattr("__name__")?.extract()?, - obj.getattr("__file__")?.extract()?, + module.getattr("__name__")?.extract()?, + module.getattr("__file__")?.extract()?, )) })?; @@ -107,7 +107,7 @@ impl Nac3 { global_value_ids: self.global_value_ids.clone(), class_names: Default::default(), name_to_pyid: name_to_pyid.clone(), - module: obj, + module: module.clone(), }) as Arc; let mut name_to_def = HashMap::new(); let mut name_to_type = HashMap::new(); @@ -117,6 +117,7 @@ impl Nac3 { ast::StmtKind::ClassDef { ref decorator_list, ref mut body, + ref mut bases, .. } => { let kernels = decorator_list.iter().any(|decorator| { @@ -126,6 +127,20 @@ impl Nac3 { false } }); + // Drop unregistered (i.e. host-only) base classes. + bases.retain(|base| { + Python::with_gil(|py| -> PyResult { + let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; + match &base.node { + ast::ExprKind::Name { id, .. } => { + let base_obj = module.getattr(py, id.to_string())?; + let base_id = id_fn.call1((base_obj,))?.extract()?; + Ok(registered_class_ids.contains(&base_id)) + }, + _ => Ok(true) + } + }).unwrap() + }); body.retain(|stmt| { if let ast::StmtKind::FunctionDef { ref decorator_list, .. @@ -303,9 +318,28 @@ impl Nac3 { }) } - fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { - for obj in modules.iter() { - self.register_module_impl(obj.into())?; + fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { + let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { + let mut modules: HashMap = HashMap::new(); + let mut class_ids: HashSet = HashSet::new(); + + let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; + let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; + + for function in functions.iter() { + let module = getmodule_fn.call1((function,))?.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module); + } + for class in classes.iter() { + let module = getmodule_fn.call1((class,))?.extract()?; + modules.insert(id_fn.call1((&module,))?.extract()?, module); + class_ids.insert(id_fn.call1((class,))?.extract()?); + } + Ok((modules, class_ids)) + })?; + + for module in modules.into_values() { + self.register_module(module, &class_ids)?; } Ok(()) }