nac3artiq: drop host-only base classes. Closes #80

This commit is contained in:
Sebastien Bourdeauducq 2021-11-11 16:08:29 +08:00
parent 7fc04936cb
commit c004da85f7
2 changed files with 68 additions and 30 deletions

View File

@ -1,4 +1,4 @@
from inspect import getfullargspec, getmodule from inspect import getfullargspec
from functools import wraps from functools import wraps
from types import SimpleNamespace from types import SimpleNamespace
from numpy import int32, int64 from numpy import int32, int64
@ -11,34 +11,38 @@ __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
"Core", "TTLOut", "parallel", "sequential"] "Core", "TTLOut", "parallel", "sequential"]
import device_db
core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True
registered_modules = set()
T = TypeVar('T') T = TypeVar('T')
class KernelInvariant(Generic[T]): class KernelInvariant(Generic[T]):
pass pass
def register_module_of(obj): import device_db
assert allow_module_registration core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_modules.add(getmodule(obj)) registered_functions = set()
registered_classes = set()
def register_function(fun):
assert allow_registration
registered_functions.add(fun)
def register_class(cls):
assert allow_registration
registered_classes.add(cls)
def extern(function): def extern(function):
"""Decorates a function declaration defined by the core device runtime.""" """Decorates a function declaration defined by the core device runtime."""
register_module_of(function) register_function(function)
return function return function
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""
register_module_of(function_or_method) register_function(function_or_method)
argspec = getfullargspec(function_or_method) argspec = getfullargspec(function_or_method)
if argspec.args and argspec.args[0] == "self": if argspec.args and argspec.args[0] == "self":
@wraps(function_or_method) @wraps(function_or_method)
@ -54,7 +58,7 @@ def kernel(function_or_method):
def portable(function): def portable(function):
"""Decorates a function or method to be executed on the same device (host/core device) as the caller.""" """Decorates a function or method to be executed on the same device (host/core device) as the caller."""
register_module_of(function) register_function(function)
return function return function
@ -63,7 +67,7 @@ def nac3(cls):
Decorates a class to be analyzed by NAC3. Decorates a class to be analyzed by NAC3.
All classes containing kernels or portable methods must use this decorator. All classes containing kernels or portable methods must use this decorator.
""" """
register_module_of(cls) register_class(cls)
return cls return cls
@ -104,10 +108,10 @@ class Core:
self.ref_period = core_arguments["ref_period"] self.ref_period = core_arguments["ref_period"]
def run(self, method, *args, **kwargs): def run(self, method, *args, **kwargs):
global allow_module_registration global allow_registration
if allow_module_registration: if allow_registration:
compiler.analyze_modules(registered_modules) compiler.analyze(registered_functions, registered_classes)
allow_module_registration = False allow_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):
obj = method.__self__ obj = method.__self__

View File

@ -72,15 +72,15 @@ struct Nac3 {
} }
impl Nac3 { impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet<u64>) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?; let module: &PyAny = module.extract(py)?;
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let id_fn = builtins.getattr("id")?; let id_fn = builtins.getattr("id")?;
let members: &PyList = PyModule::import(py, "inspect")? let members: &PyList = PyModule::import(py, "inspect")?
.getattr("getmembers")? .getattr("getmembers")?
.call1((obj,))? .call1((module,))?
.cast_as()?; .cast_as()?;
for member in members.iter() { for member in members.iter() {
let key: &str = member.get_item(0)?.extract()?; let key: &str = member.get_item(0)?.extract()?;
@ -88,8 +88,8 @@ impl Nac3 {
name_to_pyid.insert(key.into(), val); name_to_pyid.insert(key.into(), val);
} }
Ok(( Ok((
obj.getattr("__name__")?.extract()?, module.getattr("__name__")?.extract()?,
obj.getattr("__file__")?.extract()?, module.getattr("__file__")?.extract()?,
)) ))
})?; })?;
@ -107,7 +107,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(), global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: obj, module: module.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new(); let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new(); let mut name_to_type = HashMap::new();
@ -117,6 +117,7 @@ impl Nac3 {
ast::StmtKind::ClassDef { ast::StmtKind::ClassDef {
ref decorator_list, ref decorator_list,
ref mut body, ref mut body,
ref mut bases,
.. ..
} => { } => {
let kernels = decorator_list.iter().any(|decorator| { let kernels = decorator_list.iter().any(|decorator| {
@ -126,6 +127,20 @@ impl Nac3 {
false false
} }
}); });
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::with_gil(|py| -> PyResult<bool> {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
match &base.node {
ast::ExprKind::Name { id, .. } => {
let base_obj = module.getattr(py, id.to_string())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
},
_ => Ok(true)
}
}).unwrap()
});
body.retain(|stmt| { body.retain(|stmt| {
if let ast::StmtKind::FunctionDef { if let ast::StmtKind::FunctionDef {
ref decorator_list, .. ref decorator_list, ..
@ -303,9 +318,28 @@ impl Nac3 {
}) })
} }
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
for obj in modules.iter() { let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
self.register_module_impl(obj.into())?; let mut modules: HashMap<u64, PyObject> = HashMap::new();
let mut class_ids: HashSet<u64> = HashSet::new();
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions.iter() {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
for class in classes.iter() {
let module = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
class_ids.insert(id_fn.call1((class,))?.extract()?);
}
Ok((modules, class_ids))
})?;
for module in modules.into_values() {
self.register_module(module, &class_ids)?;
} }
Ok(()) Ok(())
} }