nac3artiq: drop host-only base classes. Closes #80

This commit is contained in:
Sebastien Bourdeauducq 2021-11-11 16:08:29 +08:00
parent 7fc04936cb
commit c004da85f7
2 changed files with 68 additions and 30 deletions

View File

@ -1,4 +1,4 @@
from inspect import getfullargspec, getmodule
from inspect import getfullargspec
from functools import wraps
from types import SimpleNamespace
from numpy import int32, int64
@ -11,34 +11,38 @@ __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"Core", "TTLOut", "parallel", "sequential"]
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True
registered_modules = set()
T = TypeVar('T')
class KernelInvariant(Generic[T]):
pass
def register_module_of(obj):
assert allow_module_registration
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_modules.add(getmodule(obj))
registered_functions = set()
registered_classes = set()
def register_function(fun):
assert allow_registration
registered_functions.add(fun)
def register_class(cls):
assert allow_registration
registered_classes.add(cls)
def extern(function):
"""Decorates a function declaration defined by the core device runtime."""
register_module_of(function)
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""
register_module_of(function_or_method)
register_function(function_or_method)
argspec = getfullargspec(function_or_method)
if argspec.args and argspec.args[0] == "self":
@wraps(function_or_method)
@ -54,7 +58,7 @@ def kernel(function_or_method):
def portable(function):
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
register_module_of(function)
register_function(function)
return function
@ -63,7 +67,7 @@ def nac3(cls):
Decorates a class to be analyzed by NAC3.
All classes containing kernels or portable methods must use this decorator.
"""
register_module_of(cls)
register_class(cls)
return cls
@ -104,10 +108,10 @@ class Core:
self.ref_period = core_arguments["ref_period"]
def run(self, method, *args, **kwargs):
global allow_module_registration
if allow_module_registration:
compiler.analyze_modules(registered_modules)
allow_module_registration = False
global allow_registration
if allow_registration:
compiler.analyze(registered_functions, registered_classes)
allow_registration = False
if hasattr(method, "__self__"):
obj = method.__self__

View File

@ -72,15 +72,15 @@ struct Nac3 {
}
impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> {
fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet<u64>) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?;
let module: &PyAny = module.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let id_fn = builtins.getattr("id")?;
let members: &PyList = PyModule::import(py, "inspect")?
.getattr("getmembers")?
.call1((obj,))?
.call1((module,))?
.cast_as()?;
for member in members.iter() {
let key: &str = member.get_item(0)?.extract()?;
@ -88,8 +88,8 @@ impl Nac3 {
name_to_pyid.insert(key.into(), val);
}
Ok((
obj.getattr("__name__")?.extract()?,
obj.getattr("__file__")?.extract()?,
module.getattr("__name__")?.extract()?,
module.getattr("__file__")?.extract()?,
))
})?;
@ -107,7 +107,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: obj,
module: module.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new();
@ -117,6 +117,7 @@ impl Nac3 {
ast::StmtKind::ClassDef {
ref decorator_list,
ref mut body,
ref mut bases,
..
} => {
let kernels = decorator_list.iter().any(|decorator| {
@ -126,6 +127,20 @@ impl Nac3 {
false
}
});
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::with_gil(|py| -> PyResult<bool> {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
match &base.node {
ast::ExprKind::Name { id, .. } => {
let base_obj = module.getattr(py, id.to_string())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
},
_ => Ok(true)
}
}).unwrap()
});
body.retain(|stmt| {
if let ast::StmtKind::FunctionDef {
ref decorator_list, ..
@ -303,9 +318,28 @@ impl Nac3 {
})
}
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> {
for obj in modules.iter() {
self.register_module_impl(obj.into())?;
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new();
let mut class_ids: HashSet<u64> = HashSet::new();
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions.iter() {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
for class in classes.iter() {
let module = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
class_ids.insert(id_fn.call1((class,))?.extract()?);
}
Ok((modules, class_ids))
})?;
for module in modules.into_values() {
self.register_module(module, &class_ids)?;
}
Ok(())
}