forked from M-Labs/nac3
nac3artiq: parse whole Python module, filter ast
This commit is contained in:
parent
8d839db553
commit
6141f01180
@ -1,4 +1,4 @@
|
||||
from inspect import isclass
|
||||
from inspect import isclass, getmodule
|
||||
from functools import wraps
|
||||
|
||||
import nac3artiq
|
||||
@ -8,28 +8,28 @@ __all__ = ["extern", "kernel"]
|
||||
|
||||
|
||||
nac3 = nac3artiq.NAC3()
|
||||
allow_object_registration = True
|
||||
allow_module_registration = True
|
||||
|
||||
|
||||
def extern(function):
|
||||
assert allow_object_registration
|
||||
nac3.register_object(function)
|
||||
assert allow_module_registration
|
||||
nac3.register_module(getmodule(function))
|
||||
return function
|
||||
|
||||
|
||||
def kernel(function_or_class):
|
||||
global allow_object_registration
|
||||
def kernel(class_or_function):
|
||||
global allow_module_registration
|
||||
|
||||
if isclass(function_or_class):
|
||||
assert allow_object_registration
|
||||
nac3.register_object(function_or_class)
|
||||
return function_or_class
|
||||
assert allow_module_registration
|
||||
nac3.register_module(getmodule(class_or_function))
|
||||
if isclass(class_or_function):
|
||||
return class_or_function
|
||||
else:
|
||||
@wraps(function_or_class)
|
||||
@wraps(class_or_function)
|
||||
def run_on_core(self, *args, **kwargs):
|
||||
global allow_object_registration
|
||||
if allow_object_registration:
|
||||
global allow_module_registration
|
||||
if allow_module_registration:
|
||||
nac3.analyze()
|
||||
allow_object_registration = False
|
||||
nac3.compile_method(self.__class__.__name__, function_or_class.__name__)
|
||||
allow_module_registration = False
|
||||
nac3.compile_method(self.__class__.__name__, class_or_function.__name__)
|
||||
return run_on_core
|
||||
|
@ -1,3 +1,4 @@
|
||||
use std::fs;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::path::Path;
|
||||
@ -5,7 +6,7 @@ use std::process::Command;
|
||||
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::exceptions;
|
||||
use rustpython_parser::parser;
|
||||
use rustpython_parser::{ast, parser};
|
||||
use inkwell::{
|
||||
passes::{PassManager, PassManagerBuilder},
|
||||
targets::*,
|
||||
@ -31,7 +32,8 @@ struct Nac3 {
|
||||
internal_resolver: Arc<ResolverInternal>,
|
||||
resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
||||
composer: TopLevelComposer,
|
||||
top_level: Option<Arc<TopLevelContext>>
|
||||
top_level: Option<Arc<TopLevelContext>>,
|
||||
registered_module_ids: HashSet<u64>
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@ -63,33 +65,57 @@ impl Nac3 {
|
||||
internal_resolver,
|
||||
resolver,
|
||||
composer,
|
||||
top_level: None
|
||||
top_level: None,
|
||||
registered_module_ids: HashSet::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn register_object(&mut self, obj: PyObject) -> PyResult<()> {
|
||||
Python::with_gil(|py| -> PyResult<()> {
|
||||
fn register_module(&mut self, obj: PyObject) -> PyResult<()> {
|
||||
let module_info = Python::with_gil(|py| -> PyResult<Option<(String, String)>> {
|
||||
let obj: &PyAny = obj.extract(py)?;
|
||||
let builtins = PyModule::import(py, "builtins")?;
|
||||
let id = builtins.getattr("id")?.call1((obj, ))?.extract()?;
|
||||
if self.registered_module_ids.insert(id) {
|
||||
Ok(Some((obj.getattr("__name__")?.extract()?, obj.getattr("__file__")?.extract()?)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
})?;
|
||||
|
||||
let source = PyModule::import(py, "inspect")?.getattr("getsource")?.call1((obj, ))?.extract()?;
|
||||
let parser_result = parser::parse_program(source).map_err(|e|
|
||||
if let Some((module_name, source_file)) = module_info {
|
||||
let source = fs::read_to_string(source_file).map_err(|e|
|
||||
exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)))?;
|
||||
let parser_result = parser::parse_program(&source).map_err(|e|
|
||||
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
|
||||
|
||||
for stmt in parser_result.into_iter() {
|
||||
let (name, def_id, ty) = self.composer.register_top_level(
|
||||
stmt,
|
||||
Some(self.resolver.clone()),
|
||||
"__main__".into(),
|
||||
).unwrap();
|
||||
let include = match &stmt.node {
|
||||
ast::StmtKind::ClassDef { decorator_list, .. } => {
|
||||
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
|
||||
{ id.to_string() == "kernel" || id.to_string() == "portable" } else { false })
|
||||
},
|
||||
ast::StmtKind::FunctionDef { decorator_list, .. } => {
|
||||
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
|
||||
{ id.to_string() == "extern" || id.to_string() == "portable" } else { false })
|
||||
},
|
||||
_ => false
|
||||
};
|
||||
|
||||
self.internal_resolver.add_id_def(name.clone(), def_id);
|
||||
if let Some(ty) = ty {
|
||||
self.internal_resolver.add_id_type(name, ty);
|
||||
if include {
|
||||
let (name, def_id, ty) = self.composer.register_top_level(
|
||||
stmt,
|
||||
Some(self.resolver.clone()),
|
||||
module_name.clone(),
|
||||
).unwrap();
|
||||
|
||||
self.internal_resolver.add_id_def(name.clone(), def_id);
|
||||
if let Some(ty) = ty {
|
||||
self.internal_resolver.add_id_type(name, ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn analyze(&mut self) -> PyResult<()> {
|
||||
|
Loading…
Reference in New Issue
Block a user