nac3artiq: filter out base class not annotated with nac3

This commit is contained in:
ychenfo 2021-11-11 04:14:55 +08:00
parent 12ab8bcd39
commit a2da1ecf05
2 changed files with 36 additions and 7 deletions

View File

@ -23,7 +23,7 @@ class virtual(Generic[T]):
compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True
registered_modules = set()
nac3annotated_class_ids = set()
class KernelInvariant(Generic[T]):
pass
@ -69,6 +69,7 @@ def nac3(cls):
All classes containing kernels or portable methods must use this decorator.
"""
register_module_of(cls)
nac3annotated_class_ids.add(id(cls))
return cls
@ -111,7 +112,7 @@ class Core:
def run(self, method, *args, **kwargs):
global allow_module_registration
if allow_module_registration:
compiler.analyze_modules(registered_modules)
compiler.analyze_modules(registered_modules, nac3annotated_class_ids)
allow_module_registration = False
if hasattr(method, "__self__"):

View File

@ -11,7 +11,7 @@ use inkwell::{
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes};
use nac3parser::{
ast::{self, StrRef},
ast::{self, StrRef, Constant::Str},
parser::{self, parse_program},
};
@ -76,7 +76,7 @@ struct Nac3 {
}
impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> {
fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?;
@ -111,7 +111,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: obj,
module: obj.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new();
@ -121,6 +121,7 @@ impl Nac3 {
ast::StmtKind::ClassDef {
ref decorator_list,
ref mut body,
ref mut bases,
..
} => {
let kernels = decorator_list.iter().any(|decorator| {
@ -146,6 +147,33 @@ impl Nac3 {
true
}
});
bases.retain(|b| {
Python::with_gil(|py| -> PyResult<bool> {
let obj: &PyAny = obj.extract(py)?;
let annot_check = |id: &str| -> bool {
let id = py.eval(
&format!("id({})", id),
Some(obj.getattr("__dict__").unwrap().extract().unwrap()),
None
).unwrap();
nac3_annotated_cls.contains(id).unwrap()
};
match &b.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
ast::ExprKind::Subscript { value, .. } => {
match &value.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
_ => unreachable!("unsupported base declaration")
}
}
_ => unreachable!("unsupported base declaration")
}
}).unwrap()
});
kernels
}
ast::StmtKind::FunctionDef {
@ -336,9 +364,9 @@ impl Nac3 {
})
}
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> {
fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> {
for obj in modules.iter() {
self.register_module_impl(obj.into())?;
self.register_module_impl(obj.into(), nac3_annotated_cls)?;
}
Ok(())
}