forked from M-Labs/nac3
nac3artiq: parse whole Python module, filter ast
This commit is contained in:
parent
8d839db553
commit
6141f01180
|
@ -1,4 +1,4 @@
|
||||||
from inspect import isclass
|
from inspect import isclass, getmodule
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import nac3artiq
|
import nac3artiq
|
||||||
|
@ -8,28 +8,28 @@ __all__ = ["extern", "kernel"]
|
||||||
|
|
||||||
|
|
||||||
nac3 = nac3artiq.NAC3()
|
nac3 = nac3artiq.NAC3()
|
||||||
allow_object_registration = True
|
allow_module_registration = True
|
||||||
|
|
||||||
|
|
||||||
def extern(function):
|
def extern(function):
|
||||||
assert allow_object_registration
|
assert allow_module_registration
|
||||||
nac3.register_object(function)
|
nac3.register_module(getmodule(function))
|
||||||
return function
|
return function
|
||||||
|
|
||||||
|
|
||||||
def kernel(function_or_class):
|
def kernel(class_or_function):
|
||||||
global allow_object_registration
|
global allow_module_registration
|
||||||
|
|
||||||
if isclass(function_or_class):
|
assert allow_module_registration
|
||||||
assert allow_object_registration
|
nac3.register_module(getmodule(class_or_function))
|
||||||
nac3.register_object(function_or_class)
|
if isclass(class_or_function):
|
||||||
return function_or_class
|
return class_or_function
|
||||||
else:
|
else:
|
||||||
@wraps(function_or_class)
|
@wraps(class_or_function)
|
||||||
def run_on_core(self, *args, **kwargs):
|
def run_on_core(self, *args, **kwargs):
|
||||||
global allow_object_registration
|
global allow_module_registration
|
||||||
if allow_object_registration:
|
if allow_module_registration:
|
||||||
nac3.analyze()
|
nac3.analyze()
|
||||||
allow_object_registration = False
|
allow_module_registration = False
|
||||||
nac3.compile_method(self.__class__.__name__, function_or_class.__name__)
|
nac3.compile_method(self.__class__.__name__, class_or_function.__name__)
|
||||||
return run_on_core
|
return run_on_core
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use std::fs;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
@ -5,7 +6,7 @@ use std::process::Command;
|
||||||
|
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::exceptions;
|
use pyo3::exceptions;
|
||||||
use rustpython_parser::parser;
|
use rustpython_parser::{ast, parser};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
passes::{PassManager, PassManagerBuilder},
|
passes::{PassManager, PassManagerBuilder},
|
||||||
targets::*,
|
targets::*,
|
||||||
|
@ -31,7 +32,8 @@ struct Nac3 {
|
||||||
internal_resolver: Arc<ResolverInternal>,
|
internal_resolver: Arc<ResolverInternal>,
|
||||||
resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
resolver: Arc<Box<dyn SymbolResolver + Send + Sync>>,
|
||||||
composer: TopLevelComposer,
|
composer: TopLevelComposer,
|
||||||
top_level: Option<Arc<TopLevelContext>>
|
top_level: Option<Arc<TopLevelContext>>,
|
||||||
|
registered_module_ids: HashSet<u64>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
@ -63,23 +65,47 @@ impl Nac3 {
|
||||||
internal_resolver,
|
internal_resolver,
|
||||||
resolver,
|
resolver,
|
||||||
composer,
|
composer,
|
||||||
top_level: None
|
top_level: None,
|
||||||
|
registered_module_ids: HashSet::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_object(&mut self, obj: PyObject) -> PyResult<()> {
|
fn register_module(&mut self, obj: PyObject) -> PyResult<()> {
|
||||||
Python::with_gil(|py| -> PyResult<()> {
|
let module_info = Python::with_gil(|py| -> PyResult<Option<(String, String)>> {
|
||||||
let obj: &PyAny = obj.extract(py)?;
|
let obj: &PyAny = obj.extract(py)?;
|
||||||
|
let builtins = PyModule::import(py, "builtins")?;
|
||||||
|
let id = builtins.getattr("id")?.call1((obj, ))?.extract()?;
|
||||||
|
if self.registered_module_ids.insert(id) {
|
||||||
|
Ok(Some((obj.getattr("__name__")?.extract()?, obj.getattr("__file__")?.extract()?)))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
let source = PyModule::import(py, "inspect")?.getattr("getsource")?.call1((obj, ))?.extract()?;
|
if let Some((module_name, source_file)) = module_info {
|
||||||
let parser_result = parser::parse_program(source).map_err(|e|
|
let source = fs::read_to_string(source_file).map_err(|e|
|
||||||
|
exceptions::PyIOError::new_err(format!("failed to read input file: {}", e)))?;
|
||||||
|
let parser_result = parser::parse_program(&source).map_err(|e|
|
||||||
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
|
exceptions::PySyntaxError::new_err(format!("failed to parse host object source: {}", e)))?;
|
||||||
|
|
||||||
for stmt in parser_result.into_iter() {
|
for stmt in parser_result.into_iter() {
|
||||||
|
let include = match &stmt.node {
|
||||||
|
ast::StmtKind::ClassDef { decorator_list, .. } => {
|
||||||
|
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
|
||||||
|
{ id.to_string() == "kernel" || id.to_string() == "portable" } else { false })
|
||||||
|
},
|
||||||
|
ast::StmtKind::FunctionDef { decorator_list, .. } => {
|
||||||
|
decorator_list.iter().any(|decorator| if let ast::ExprKind::Name { id, .. } = decorator.node
|
||||||
|
{ id.to_string() == "extern" || id.to_string() == "portable" } else { false })
|
||||||
|
},
|
||||||
|
_ => false
|
||||||
|
};
|
||||||
|
|
||||||
|
if include {
|
||||||
let (name, def_id, ty) = self.composer.register_top_level(
|
let (name, def_id, ty) = self.composer.register_top_level(
|
||||||
stmt,
|
stmt,
|
||||||
Some(self.resolver.clone()),
|
Some(self.resolver.clone()),
|
||||||
"__main__".into(),
|
module_name.clone(),
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
||||||
self.internal_resolver.add_id_def(name.clone(), def_id);
|
self.internal_resolver.add_id_def(name.clone(), def_id);
|
||||||
|
@ -87,9 +113,9 @@ impl Nac3 {
|
||||||
self.internal_resolver.add_id_type(name, ty);
|
self.internal_resolver.add_id_type(name, ty);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze(&mut self) -> PyResult<()> {
|
fn analyze(&mut self) -> PyResult<()> {
|
||||||
|
|
Loading…
Reference in New Issue