forked from M-Labs/nac3
nac3artiq: drop host-only base classes. Closes #80
This commit is contained in:
parent
7fc04936cb
commit
c004da85f7
|
@ -1,4 +1,4 @@
|
||||||
from inspect import getfullargspec, getmodule
|
from inspect import getfullargspec
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from numpy import int32, int64
|
from numpy import int32, int64
|
||||||
|
@ -11,34 +11,38 @@ __all__ = ["KernelInvariant", "extern", "kernel", "portable", "nac3",
|
||||||
"Core", "TTLOut", "parallel", "sequential"]
|
"Core", "TTLOut", "parallel", "sequential"]
|
||||||
|
|
||||||
|
|
||||||
import device_db
|
|
||||||
core_arguments = device_db.device_db["core"]["arguments"]
|
|
||||||
|
|
||||||
compiler = nac3artiq.NAC3(core_arguments["target"])
|
|
||||||
allow_module_registration = True
|
|
||||||
registered_modules = set()
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
class KernelInvariant(Generic[T]):
|
class KernelInvariant(Generic[T]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def register_module_of(obj):
|
import device_db
|
||||||
assert allow_module_registration
|
core_arguments = device_db.device_db["core"]["arguments"]
|
||||||
|
|
||||||
|
compiler = nac3artiq.NAC3(core_arguments["target"])
|
||||||
|
allow_registration = True
|
||||||
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
|
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
|
||||||
registered_modules.add(getmodule(obj))
|
registered_functions = set()
|
||||||
|
registered_classes = set()
|
||||||
|
|
||||||
|
def register_function(fun):
|
||||||
|
assert allow_registration
|
||||||
|
registered_functions.add(fun)
|
||||||
|
|
||||||
|
def register_class(cls):
|
||||||
|
assert allow_registration
|
||||||
|
registered_classes.add(cls)
|
||||||
|
|
||||||
|
|
||||||
def extern(function):
|
def extern(function):
|
||||||
"""Decorates a function declaration defined by the core device runtime."""
|
"""Decorates a function declaration defined by the core device runtime."""
|
||||||
register_module_of(function)
|
register_function(function)
|
||||||
return function
|
return function
|
||||||
|
|
||||||
|
|
||||||
def kernel(function_or_method):
|
def kernel(function_or_method):
|
||||||
"""Decorates a function or method to be executed on the core device."""
|
"""Decorates a function or method to be executed on the core device."""
|
||||||
register_module_of(function_or_method)
|
register_function(function_or_method)
|
||||||
argspec = getfullargspec(function_or_method)
|
argspec = getfullargspec(function_or_method)
|
||||||
if argspec.args and argspec.args[0] == "self":
|
if argspec.args and argspec.args[0] == "self":
|
||||||
@wraps(function_or_method)
|
@wraps(function_or_method)
|
||||||
|
@ -54,7 +58,7 @@ def kernel(function_or_method):
|
||||||
|
|
||||||
def portable(function):
|
def portable(function):
|
||||||
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
|
"""Decorates a function or method to be executed on the same device (host/core device) as the caller."""
|
||||||
register_module_of(function)
|
register_function(function)
|
||||||
return function
|
return function
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +67,7 @@ def nac3(cls):
|
||||||
Decorates a class to be analyzed by NAC3.
|
Decorates a class to be analyzed by NAC3.
|
||||||
All classes containing kernels or portable methods must use this decorator.
|
All classes containing kernels or portable methods must use this decorator.
|
||||||
"""
|
"""
|
||||||
register_module_of(cls)
|
register_class(cls)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,10 +108,10 @@ class Core:
|
||||||
self.ref_period = core_arguments["ref_period"]
|
self.ref_period = core_arguments["ref_period"]
|
||||||
|
|
||||||
def run(self, method, *args, **kwargs):
|
def run(self, method, *args, **kwargs):
|
||||||
global allow_module_registration
|
global allow_registration
|
||||||
if allow_module_registration:
|
if allow_registration:
|
||||||
compiler.analyze_modules(registered_modules)
|
compiler.analyze(registered_functions, registered_classes)
|
||||||
allow_module_registration = False
|
allow_registration = False
|
||||||
|
|
||||||
if hasattr(method, "__self__"):
|
if hasattr(method, "__self__"):
|
||||||
obj = method.__self__
|
obj = method.__self__
|
||||||
|
|
|
@ -72,15 +72,15 @@ struct Nac3 {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Nac3 {
|
impl Nac3 {
|
||||||
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> {
|
fn register_module(&mut self, module: PyObject, registered_class_ids: &HashSet<u64>) -> PyResult<()> {
|
||||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||||
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
||||||
let obj: &PyAny = obj.extract(py)?;
|
let module: &PyAny = module.extract(py)?;
|
||||||
let builtins = PyModule::import(py, "builtins")?;
|
let builtins = PyModule::import(py, "builtins")?;
|
||||||
let id_fn = builtins.getattr("id")?;
|
let id_fn = builtins.getattr("id")?;
|
||||||
let members: &PyList = PyModule::import(py, "inspect")?
|
let members: &PyList = PyModule::import(py, "inspect")?
|
||||||
.getattr("getmembers")?
|
.getattr("getmembers")?
|
||||||
.call1((obj,))?
|
.call1((module,))?
|
||||||
.cast_as()?;
|
.cast_as()?;
|
||||||
for member in members.iter() {
|
for member in members.iter() {
|
||||||
let key: &str = member.get_item(0)?.extract()?;
|
let key: &str = member.get_item(0)?.extract()?;
|
||||||
|
@ -88,8 +88,8 @@ impl Nac3 {
|
||||||
name_to_pyid.insert(key.into(), val);
|
name_to_pyid.insert(key.into(), val);
|
||||||
}
|
}
|
||||||
Ok((
|
Ok((
|
||||||
obj.getattr("__name__")?.extract()?,
|
module.getattr("__name__")?.extract()?,
|
||||||
obj.getattr("__file__")?.extract()?,
|
module.getattr("__file__")?.extract()?,
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ impl Nac3 {
|
||||||
global_value_ids: self.global_value_ids.clone(),
|
global_value_ids: self.global_value_ids.clone(),
|
||||||
class_names: Default::default(),
|
class_names: Default::default(),
|
||||||
name_to_pyid: name_to_pyid.clone(),
|
name_to_pyid: name_to_pyid.clone(),
|
||||||
module: obj,
|
module: module.clone(),
|
||||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
let mut name_to_def = HashMap::new();
|
let mut name_to_def = HashMap::new();
|
||||||
let mut name_to_type = HashMap::new();
|
let mut name_to_type = HashMap::new();
|
||||||
|
@ -117,6 +117,7 @@ impl Nac3 {
|
||||||
ast::StmtKind::ClassDef {
|
ast::StmtKind::ClassDef {
|
||||||
ref decorator_list,
|
ref decorator_list,
|
||||||
ref mut body,
|
ref mut body,
|
||||||
|
ref mut bases,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
let kernels = decorator_list.iter().any(|decorator| {
|
let kernels = decorator_list.iter().any(|decorator| {
|
||||||
|
@ -126,6 +127,20 @@ impl Nac3 {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
// Drop unregistered (i.e. host-only) base classes.
|
||||||
|
bases.retain(|base| {
|
||||||
|
Python::with_gil(|py| -> PyResult<bool> {
|
||||||
|
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||||
|
match &base.node {
|
||||||
|
ast::ExprKind::Name { id, .. } => {
|
||||||
|
let base_obj = module.getattr(py, id.to_string())?;
|
||||||
|
let base_id = id_fn.call1((base_obj,))?.extract()?;
|
||||||
|
Ok(registered_class_ids.contains(&base_id))
|
||||||
|
},
|
||||||
|
_ => Ok(true)
|
||||||
|
}
|
||||||
|
}).unwrap()
|
||||||
|
});
|
||||||
body.retain(|stmt| {
|
body.retain(|stmt| {
|
||||||
if let ast::StmtKind::FunctionDef {
|
if let ast::StmtKind::FunctionDef {
|
||||||
ref decorator_list, ..
|
ref decorator_list, ..
|
||||||
|
@ -303,9 +318,28 @@ impl Nac3 {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> {
|
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
||||||
for obj in modules.iter() {
|
let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||||
self.register_module_impl(obj.into())?;
|
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||||
|
let mut class_ids: HashSet<u64> = HashSet::new();
|
||||||
|
|
||||||
|
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||||
|
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||||
|
|
||||||
|
for function in functions.iter() {
|
||||||
|
let module = getmodule_fn.call1((function,))?.extract()?;
|
||||||
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
|
}
|
||||||
|
for class in classes.iter() {
|
||||||
|
let module = getmodule_fn.call1((class,))?.extract()?;
|
||||||
|
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||||
|
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||||
|
}
|
||||||
|
Ok((modules, class_ids))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
for module in modules.into_values() {
|
||||||
|
self.register_module(module, &class_ids)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue