1
0
forked from M-Labs/nac3

nac3artiq: parse whole Python module, filter ast

This commit is contained in:
Sebastien Bourdeauducq 2021-09-27 22:12:25 +08:00
parent 8d839db553
commit 6141f01180
2 changed files with 59 additions and 33 deletions

View File

@ -1,4 +1,4 @@
from inspect import isclass
from inspect import isclass, getmodule
from functools import wraps
import nac3artiq
@ -8,28 +8,28 @@ __all__ = ["extern", "kernel"]
nac3 = nac3artiq.NAC3()
allow_object_registration = True
allow_module_registration = True
def extern(function):
assert allow_object_registration
nac3.register_object(function)
assert allow_module_registration
nac3.register_module(getmodule(function))
return function
def kernel(function_or_class):
global allow_object_registration
def kernel(class_or_function):
global allow_module_registration
if isclass(function_or_class):
assert allow_object_registration
nac3.register_object(function_or_class)
return function_or_class
assert allow_module_registration
nac3.register_module(getmodule(class_or_function))
if isclass(class_or_function):
return class_or_function
else:
@wraps(function_or_class)
@wraps(class_or_function)
def run_on_core(self, *args, **kwargs):
global allow_object_registration
if allow_object_registration:
global allow_module_registration
if allow_module_registration:
nac3.analyze()
allow_object_registration = False
nac3.compile_method(self.__class__.__name__, function_or_class.__name__)
allow_module_registration = False
nac3.compile_method(self.__class__.__name__, class_or_function.__name__)
return run_on_core

View File

@ -1,3 +1,4 @@
use std::fs;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::path::Path;
@ -5,7 +6,7 @@ use std::process::Command;
use pyo3::prelude::*;
use pyo3::exceptions;
use rustpython_parser::parser;
use rustpython_parser::{ast, parser};
use inkwell::{
passes::{PassManager, PassManagerBuilder},
targets::*,
@ -31,7 +32,8 @@ struct Nac3 {
internal_resolver: Arc<ResolverInternal>,
resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
composer: TopLevelComposer,
top_level: Option<Arc<TopLevelContext>>
top_level: Option<Arc<TopLevelContext>>,
registered_module_ids: HashSet<u64>
}
#[pymethods]
@ -63,33 +65,57 @@ impl Nac3 {
internal_resolver,
resolver,
composer,
top_level: None
top_level: None,
registered_module_ids: HashSet::new()
}
}
fn register_object(&mut self, obj: PyObject) -> PyResult<()> {
Python::with_gil(|py| -> PyResult<()> {
fn register_module(&mut self, obj: PyObject) -> PyResult<()> {
let module_info = Python::with_gil(|py| -> PyResult<Option<(String, String)>> {
let obj: &PyAny = obj.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let id = builtins.getattr("id")?.call1((obj, ))?.extract()?;
if self.registered_module_ids.insert(id) {
Ok(Some((obj.getattr("__name__")?.extract()?, obj.getattr("__file__")?.extract()?)))
} else {
Ok(None)
}
})?;
let source = PyModule::import(py, "inspect")?.getattr("getsource")?.call1((obj, ))?.extract()?;
let parser_result = parser::parse_program(source).map_err(|e|
if let Some((module_name, source_file)) = module_info {
let source = fs::read_to_string(source_file).map_err(|e|
exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)))?;
let parser_result = parser::parse_program(&source).map_err(|e|
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
for stmt in parser_result.into_iter() {
let (name, def_id, ty) = self.composer.register_top_level(
stmt,
Some(self.resolver.clone()),
"__main__".into(),
).unwrap();
let include = match &stmt.node {
ast::StmtKind::ClassDef { decorator_list, .. } => {
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
{ id.to_string() == "kernel" || id.to_string() == "portable" } else { false })
},
ast::StmtKind::FunctionDef { decorator_list, .. } => {
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
{ id.to_string() == "extern" || id.to_string() == "portable" } else { false })
},
_ => false
};
self.internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty {
self.internal_resolver.add_id_type(name, ty);
if include {
let (name, def_id, ty) = self.composer.register_top_level(
stmt,
Some(self.resolver.clone()),
module_name.clone(),
).unwrap();
self.internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty {
self.internal_resolver.add_id_type(name, ty);
}
}
}
Ok(())
})
}
Ok(())
}
fn analyze(&mut self) -> PyResult<()> {