nac3artiq: drop host-only base classes. Closes #80
This commit is contained in:
parent
7fc04936cb
commit
c004da85f7
@ -1,4 +1,4 @@
|
||||
from inspect import getfullargspec, getmodule
|
||||
from inspect import getfullargspec
|
||||
from functools import wraps
|
||||
from types import SimpleNamespace
|
||||
from numpy import int32, int64
|
||||
@ -11,34 +11,38 @@ __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
|
||||
"Core", "TTLOut", "parallel", "sequential"]
|
||||
|
||||
|
||||
import device_db
|
||||
core_arguments = device_db.device_db["core"]["arguments"]
|
||||
|
||||
compiler = nac3artiq.NAC3(core_arguments["target"])
|
||||
allow_module_registration = True
|
||||
registered_modules = set()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
class KernelInvariant(Generic[T]):
|
||||
pass
|
||||
|
||||
|
||||
def register_module_of(obj):
|
||||
assert allow_module_registration
|
||||
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
|
||||
registered_modules.add(getmodule(obj))
|
||||
import device_db
|
||||
core_arguments = device_db.device_db["core"]["arguments"]
|
||||
|
||||
compiler = nac3artiq.NAC3(core_arguments["target"])
|
||||
allow_registration = True
|
||||
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
|
||||
registered_functions = set()
|
||||
registered_classes = set()
|
||||
|
||||
def register_function(fun):
|
||||
assert allow_registration
|
||||
registered_functions.add(fun)
|
||||
|
||||
def register_class(cls):
|
||||
assert allow_registration
|
||||
registered_classes.add(cls)
|
||||
|
||||
|
||||
def extern(function):
|
||||
"""Decorates a function declaration defined by the core device runtime."""
|
||||
register_module_of(function)
|
||||
register_function(function)
|
||||
return function
|
||||
|
||||
|
||||
def kernel(function_or_method):
|
||||
"""Decorates a function or method to be executed on the core device."""
|
||||
register_module_of(function_or_method)
|
||||
register_function(function_or_method)
|
||||
argspec = getfullargspec(function_or_method)
|
||||
if argspec.args and argspec.args[0] == "self":
|
||||
@wraps(function_or_method)
|
||||
@ -54,7 +58,7 @@ def kernel(function_or_method):
|
||||
|
||||
def portable(function):
|
||||
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
|
||||
register_module_of(function)
|
||||
register_function(function)
|
||||
return function
|
||||
|
||||
|
||||
@ -63,7 +67,7 @@ def nac3(cls):
|
||||
Decorates a class to be analyzed by NAC3.
|
||||
All classes containing kernels or portable methods must use this decorator.
|
||||
"""
|
||||
register_module_of(cls)
|
||||
register_class(cls)
|
||||
return cls
|
||||
|
||||
|
||||
@ -104,10 +108,10 @@ class Core:
|
||||
self.ref_period = core_arguments["ref_period"]
|
||||
|
||||
def run(self, method, *args, **kwargs):
|
||||
global allow_module_registration
|
||||
if allow_module_registration:
|
||||
compiler.analyze_modules(registered_modules)
|
||||
allow_module_registration = False
|
||||
global allow_registration
|
||||
if allow_registration:
|
||||
compiler.analyze(registered_functions, registered_classes)
|
||||
allow_registration = False
|
||||
|
||||
if hasattr(method, "__self__"):
|
||||
obj = method.__self__
|
||||
|
@ -72,15 +72,15 @@ struct Nac3 {
|
||||
}
|
||||
|
||||
impl Nac3 {
|
||||
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> {
|
||||
fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet<u64>) -> PyResult<()> {
|
||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
||||
let obj: &PyAny = obj.extract(py)?;
|
||||
let module: &PyAny = module.extract(py)?;
|
||||
let builtins = PyModule::import(py, "builtins")?;
|
||||
let id_fn = builtins.getattr("id")?;
|
||||
let members: &PyList = PyModule::import(py, "inspect")?
|
||||
.getattr("getmembers")?
|
||||
.call1((obj,))?
|
||||
.call1((module,))?
|
||||
.cast_as()?;
|
||||
for member in members.iter() {
|
||||
let key: &str = member.get_item(0)?.extract()?;
|
||||
@ -88,8 +88,8 @@ impl Nac3 {
|
||||
name_to_pyid.insert(key.into(), val);
|
||||
}
|
||||
Ok((
|
||||
obj.getattr("__name__")?.extract()?,
|
||||
obj.getattr("__file__")?.extract()?,
|
||||
module.getattr("__name__")?.extract()?,
|
||||
module.getattr("__file__")?.extract()?,
|
||||
))
|
||||
})?;
|
||||
|
||||
@ -107,7 +107,7 @@ impl Nac3 {
|
||||
global_value_ids: self.global_value_ids.clone(),
|
||||
class_names: Default::default(),
|
||||
name_to_pyid: name_to_pyid.clone(),
|
||||
module: obj,
|
||||
module: module.clone(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let mut name_to_def = HashMap::new();
|
||||
let mut name_to_type = HashMap::new();
|
||||
@ -117,6 +117,7 @@ impl Nac3 {
|
||||
ast::StmtKind::ClassDef {
|
||||
ref decorator_list,
|
||||
ref mut body,
|
||||
ref mut bases,
|
||||
..
|
||||
} => {
|
||||
let kernels = decorator_list.iter().any(|decorator| {
|
||||
@ -126,6 +127,20 @@ impl Nac3 {
|
||||
false
|
||||
}
|
||||
});
|
||||
// Drop unregistered (i.e. host-only) base classes.
|
||||
bases.retain(|base| {
|
||||
Python::with_gil(|py| -> PyResult<bool> {
|
||||
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||
match &base.node {
|
||||
ast::ExprKind::Name { id, .. } => {
|
||||
let base_obj = module.getattr(py, id.to_string())?;
|
||||
let base_id = id_fn.call1((base_obj,))?.extract()?;
|
||||
Ok(registered_class_ids.contains(&base_id))
|
||||
},
|
||||
_ => Ok(true)
|
||||
}
|
||||
}).unwrap()
|
||||
});
|
||||
body.retain(|stmt| {
|
||||
if let ast::StmtKind::FunctionDef {
|
||||
ref decorator_list, ..
|
||||
@ -303,9 +318,28 @@ impl Nac3 {
|
||||
})
|
||||
}
|
||||
|
||||
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> {
|
||||
for obj in modules.iter() {
|
||||
self.register_module_impl(obj.into())?;
|
||||
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
||||
let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||
let mut class_ids: HashSet<u64> = HashSet::new();
|
||||
|
||||
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||
|
||||
for function in functions.iter() {
|
||||
let module = getmodule_fn.call1((function,))?.extract()?;
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
}
|
||||
for class in classes.iter() {
|
||||
let module = getmodule_fn.call1((class,))?.extract()?;
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||
}
|
||||
Ok((modules, class_ids))
|
||||
})?;
|
||||
|
||||
for module in modules.into_values() {
|
||||
self.register_module(module, &class_ids)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user