nac3artiq: parse whole Python module, filter ast

This commit is contained in:
Sebastien Bourdeauducq 2021-09-27 22:12:25 +08:00
parent 8d839db553
commit 6141f01180
2 changed files with 59 additions and 33 deletions

View File

@ -1,4 +1,4 @@
from inspect import isclass from inspect import isclass, getmodule
from functools import wraps from functools import wraps
import nac3artiq import nac3artiq
@ -8,28 +8,28 @@ __all__ = ["extern", "kernel"]
nac3 = nac3artiq.NAC3() nac3 = nac3artiq.NAC3()
allow_object_registration = True allow_module_registration = True
def extern(function): def extern(function):
assert allow_object_registration assert allow_module_registration
nac3.register_object(function) nac3.register_module(getmodule(function))
return function return function
def kernel(function_or_class): def kernel(class_or_function):
global allow_object_registration global allow_module_registration
if isclass(function_or_class): assert allow_module_registration
assert allow_object_registration nac3.register_module(getmodule(class_or_function))
nac3.register_object(function_or_class) if isclass(class_or_function):
return function_or_class return class_or_function
else: else:
@wraps(function_or_class) @wraps(class_or_function)
def run_on_core(self, *args, **kwargs): def run_on_core(self, *args, **kwargs):
global allow_object_registration global allow_module_registration
if allow_object_registration: if allow_module_registration:
nac3.analyze() nac3.analyze()
allow_object_registration = False allow_module_registration = False
nac3.compile_method(self.__class__.__name__, function_or_class.__name__) nac3.compile_method(self.__class__.__name__, class_or_function.__name__)
return run_on_core return run_on_core

View File

@ -1,3 +1,4 @@
use std::fs;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::path::Path; use std::path::Path;
@ -5,7 +6,7 @@ use std::process::Command;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::exceptions; use pyo3::exceptions;
use rustpython_parser::parser; use rustpython_parser::{ast, parser};
use inkwell::{ use inkwell::{
passes::{PassManager, PassManagerBuilder}, passes::{PassManager, PassManagerBuilder},
targets::*, targets::*,
@ -31,7 +32,8 @@ struct Nac3 {
internal_resolver: Arc<ResolverInternal>, internal_resolver: Arc<ResolverInternal>,
resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>, resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
composer: TopLevelComposer, composer: TopLevelComposer,
top_level: Option<Arc<TopLevelContext>> top_level: Option<Arc<TopLevelContext>>,
registered_module_ids: HashSet<u64>
} }
#[pymethods] #[pymethods]
@ -63,23 +65,47 @@ impl Nac3 {
internal_resolver, internal_resolver,
resolver, resolver,
composer, composer,
top_level: None top_level: None,
registered_module_ids: HashSet::new()
} }
} }
fn register_object(&mut self, obj: PyObject) -> PyResult<()> { fn register_module(&mut self, obj: PyObject) -> PyResult<()> {
Python::with_gil(|py| -> PyResult<()> { let module_info = Python::with_gil(|py| -> PyResult<Option<(String, String)>> {
let obj: &PyAny = obj.extract(py)?; let obj: &PyAny = obj.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let id = builtins.getattr("id")?.call1((obj, ))?.extract()?;
if self.registered_module_ids.insert(id) {
Ok(Some((obj.getattr("__name__")?.extract()?, obj.getattr("__file__")?.extract()?)))
} else {
Ok(None)
}
})?;
let source = PyModule::import(py, "inspect")?.getattr("getsource")?.call1((obj, ))?.extract()?; if let Some((module_name, source_file)) = module_info {
let parser_result = parser::parse_program(source).map_err(|e| let source = fs::read_to_string(source_file).map_err(|e|
exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)))?;
let parser_result = parser::parse_program(&source).map_err(|e|
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?; exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
for stmt in parser_result.into_iter() { for stmt in parser_result.into_iter() {
let include = match &stmt.node {
ast::StmtKind::ClassDef { decorator_list, .. } => {
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
{ id.to_string() == "kernel" || id.to_string() == "portable" } else { false })
},
ast::StmtKind::FunctionDef { decorator_list, .. } => {
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
{ id.to_string() == "extern" || id.to_string() == "portable" } else { false })
},
_ => false
};
if include {
let (name, def_id, ty) = self.composer.register_top_level( let (name, def_id, ty) = self.composer.register_top_level(
stmt, stmt,
Some(self.resolver.clone()), Some(self.resolver.clone()),
"__main__".into(), module_name.clone(),
).unwrap(); ).unwrap();
self.internal_resolver.add_id_def(name.clone(), def_id); self.internal_resolver.add_id_def(name.clone(), def_id);
@ -87,9 +113,9 @@ impl Nac3 {
self.internal_resolver.add_id_type(name, ty); self.internal_resolver.add_id_type(name, ty);
} }
} }
}
}
Ok(()) Ok(())
})
} }
fn analyze(&mut self) -> PyResult<()> { fn analyze(&mut self) -> PyResult<()> {