nac3artiq: filter out base class not annotated with nac3

This commit is contained in:
ychenfo 2021-11-11 04:14:55 +08:00
parent 12ab8bcd39
commit a2da1ecf05
2 changed files with 36 additions and 7 deletions

View File

@ -23,7 +23,7 @@ class virtual(Generic[T]):
compiler = nac3artiq.NAC3(core_arguments["target"]) compiler = nac3artiq.NAC3(core_arguments["target"])
allow_module_registration = True allow_module_registration = True
registered_modules = set() registered_modules = set()
nac3annotated_class_ids = set()
class KernelInvariant(Generic[T]): class KernelInvariant(Generic[T]):
pass pass
@ -69,6 +69,7 @@ def nac3(cls):
All classes containing kernels or portable methods must use this decorator. All classes containing kernels or portable methods must use this decorator.
""" """
register_module_of(cls) register_module_of(cls)
nac3annotated_class_ids.add(id(cls))
return cls return cls
@ -111,7 +112,7 @@ class Core:
def run(self, method, *args, **kwargs): def run(self, method, *args, **kwargs):
global allow_module_registration global allow_module_registration
if allow_module_registration: if allow_module_registration:
compiler.analyze_modules(registered_modules) compiler.analyze_modules(registered_modules, nac3annotated_class_ids)
allow_module_registration = False allow_module_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

View File

@ -11,7 +11,7 @@ use inkwell::{
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes}; use pyo3::{exceptions, types::PyList, types::PySet, types::PyBytes};
use nac3parser::{ use nac3parser::{
ast::{self, StrRef}, ast::{self, StrRef, Constant::Str},
parser::{self, parse_program}, parser::{self, parse_program},
}; };
@ -76,7 +76,7 @@ struct Nac3 {
} }
impl Nac3 { impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> { fn register_module_impl(&mut self, obj: PyObject, nac3_annotated_cls: &PySet) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?; let obj: &PyAny = obj.extract(py)?;
@ -111,7 +111,7 @@ impl Nac3 {
global_value_ids: self.global_value_ids.clone(), global_value_ids: self.global_value_ids.clone(),
class_names: Default::default(), class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: obj, module: obj.clone(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let mut name_to_def = HashMap::new(); let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new(); let mut name_to_type = HashMap::new();
@ -121,6 +121,7 @@ impl Nac3 {
ast::StmtKind::ClassDef { ast::StmtKind::ClassDef {
ref decorator_list, ref decorator_list,
ref mut body, ref mut body,
ref mut bases,
.. ..
} => { } => {
let kernels = decorator_list.iter().any(|decorator| { let kernels = decorator_list.iter().any(|decorator| {
@ -146,6 +147,33 @@ impl Nac3 {
true true
} }
}); });
bases.retain(|b| {
Python::with_gil(|py| -> PyResult<bool> {
let obj: &PyAny = obj.extract(py)?;
let annot_check = |id: &str| -> bool {
let id = py.eval(
&format!("id({})", id),
Some(obj.getattr("__dict__").unwrap().extract().unwrap()),
None
).unwrap();
nac3_annotated_cls.contains(id).unwrap()
};
match &b.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string())),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
ast::ExprKind::Subscript { value, .. } => {
match &value.node {
ast::ExprKind::Name { id, .. } => Ok(annot_check(&id.to_string()) || *id == "Generic".into()),
ast::ExprKind::Constant { value: Str(id), .. } =>
Ok(annot_check(id.split('[').next().unwrap())),
_ => unreachable!("unsupported base declaration")
}
}
_ => unreachable!("unsupported base declaration")
}
}).unwrap()
});
kernels kernels
} }
ast::StmtKind::FunctionDef { ast::StmtKind::FunctionDef {
@ -336,9 +364,9 @@ impl Nac3 {
}) })
} }
fn analyze_modules(&mut self, modules: &PySet) -> PyResult<()> { fn analyze_modules(&mut self, modules: &PySet, nac3_annotated_cls: &PySet) -> PyResult<()> {
for obj in modules.iter() { for obj in modules.iter() {
self.register_module_impl(obj.into())?; self.register_module_impl(obj.into(), nac3_annotated_cls)?;
} }
Ok(()) Ok(())
} }