1
0
forked from M-Labs/nac3
nac3/nac3artiq/src/lib.rs

394 lines
14 KiB
Rust
Raw Normal View History

use std::collections::{HashMap, HashSet};
use std::fs;
2021-09-23 19:30:03 +08:00
use std::path::Path;
2021-09-23 21:30:13 +08:00
use std::process::Command;
2021-09-30 22:30:54 +08:00
use std::sync::Arc;
2020-12-19 15:29:39 +08:00
2021-09-23 19:30:03 +08:00
use inkwell::{
passes::{PassManager, PassManagerBuilder},
targets::*,
OptimizationLevel,
};
2021-09-30 22:30:54 +08:00
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyList};
use rustpython_parser::{
ast::{self, StrRef},
parser,
};
use parking_lot::{RwLock, Mutex};
2020-12-19 16:23:12 +08:00
2021-09-23 19:30:03 +08:00
use nac3core::{
2021-09-30 17:07:48 +08:00
codegen::{CodeGenTask, WithCall, WorkerRegistry},
2021-09-23 19:30:03 +08:00
symbol_resolver::SymbolResolver,
toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef},
typecheck::typedef::FunSignature,
2021-09-23 19:30:03 +08:00
};
2021-09-30 22:30:54 +08:00
use nac3core::{
toplevel::DefinitionId,
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
};
use crate::symbol_resolver::Resolver;
2020-12-19 15:29:39 +08:00
mod builtins;
2021-09-23 19:30:03 +08:00
mod symbol_resolver;
2020-12-19 15:29:39 +08:00
/// Instruction-set architecture that kernels are compiled for.
///
/// Chosen from the `isa` string passed to `NAC3(...)`; it selects the LLVM
/// target triple and CPU features used when emitting object files.
#[derive(Clone, Copy, PartialEq)]
enum Isa {
    /// 32-bit RISC-V (`"riscv"`), emitted for `riscv32-unknown-linux`.
    RiscV,
    /// ARM Cortex-A9 (`"cortexa9"`), emitted for `armv7-unknown-linux-gnueabihf`.
    CortexA9,
}
/// Python `id()` values of the primitive type objects.
///
/// Captured once when the compiler object is constructed and handed (cloned)
/// to every module's symbol resolver.
#[derive(Clone)]
pub struct PrimitivePythonId {
    /// `id(builtins.int)`
    int: u64,
    /// `id(numpy.int32)`
    int32: u64,
    /// `id(numpy.int64)`
    int64: u64,
    /// `id(builtins.float)`
    float: u64,
    /// `id(builtins.bool)`
    bool: u64,
    /// `id(builtins.list)`
    list: u64,
    /// `id(builtins.tuple)`
    tuple: u64,
}
2021-10-02 18:28:44 +08:00
// TopLevelComposer is unsendable as it holds the unification table, which is
// unsendable due to Rc. Arc would cause a performance hit.
2021-09-30 22:30:54 +08:00
#[pyclass(unsendable, name = "NAC3")]
2020-12-19 15:29:39 +08:00
struct Nac3 {
2021-09-29 15:33:12 +08:00
isa: Isa,
2021-09-23 19:30:03 +08:00
primitive: PrimitiveStore,
2021-09-30 22:30:54 +08:00
builtins_ty: HashMap<StrRef, Type>,
builtins_def: HashMap<StrRef, DefinitionId>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
2021-09-23 19:30:03 +08:00
composer: TopLevelComposer,
top_level: Option<Arc<TopLevelContext>>,
2021-09-30 22:30:54 +08:00
to_be_registered: Vec<PyObject>,
primitive_ids: PrimitivePythonId,
global_value_ids: Arc<Mutex<HashSet<u64>>>,
2021-09-30 22:30:54 +08:00
}
impl Nac3 {
fn register_module_impl(&mut self, obj: PyObject) -> PyResult<()> {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let obj: &PyAny = obj.extract(py)?;
let builtins = PyModule::import(py, "builtins")?;
let id_fn = builtins.getattr("id")?;
let members: &PyList = PyModule::import(py, "inspect")?
.getattr("getmembers")?
.call1((obj,))?
.cast_as()?;
for member in members.iter() {
let key: &str = member.get_item(0)?.extract()?;
let val = id_fn.call1((member.get_item(1)?,))?.extract()?;
name_to_pyid.insert(key.into(), val);
}
Ok((
obj.getattr("__name__")?.extract()?,
obj.getattr("__file__")?.extract()?,
))
})?;
let source = fs::read_to_string(source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!("failed to read input file: {}", e))
})?;
let parser_result = parser::parse_program(&source)
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {}", e)))?;
2021-09-30 22:30:54 +08:00
let resolver = Arc::new(Box::new(Resolver {
id_to_type: self.builtins_ty.clone().into(),
id_to_def: self.builtins_def.clone().into(),
pyid_to_def: self.pyid_to_def.clone(),
pyid_to_type: self.pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: self.global_value_ids.clone(),
2021-09-30 22:30:54 +08:00
class_names: Default::default(),
name_to_pyid: name_to_pyid.clone(),
module: obj,
2021-09-30 22:30:54 +08:00
}) as Box<dyn SymbolResolver + Send + Sync>);
let mut name_to_def = HashMap::new();
let mut name_to_type = HashMap::new();
for mut stmt in parser_result.into_iter() {
let include = match stmt.node {
ast::StmtKind::ClassDef {
ref decorator_list,
ref mut body,
..
} => {
let kernels = decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" || id.to_string() == "portable"
} else {
false
}
});
body.retain(|stmt| {
if let ast::StmtKind::FunctionDef {
ref decorator_list, ..
} = stmt.node
{
decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" || id.to_string() == "portable"
} else {
false
}
})
} else {
true
}
});
kernels
}
ast::StmtKind::FunctionDef {
ref decorator_list, ..
} => decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node {
2021-10-01 00:02:15 +08:00
let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel"
2021-09-30 22:30:54 +08:00
} else {
false
}
}),
_ => false,
};
if include {
let (name, def_id, ty) = self
.composer
.register_top_level(stmt, Some(resolver.clone()), module_name.clone())
.unwrap();
name_to_def.insert(name, def_id);
if let Some(ty) = ty {
name_to_type.insert(name, ty);
}
}
}
let mut map = self.pyid_to_def.write();
for (name, def) in name_to_def.into_iter() {
map.insert(*name_to_pyid.get(&name).unwrap(), def);
}
let mut map = self.pyid_to_type.write();
for (name, ty) in name_to_type.into_iter() {
map.insert(*name_to_pyid.get(&name).unwrap(), ty);
}
Ok(())
}
2020-12-19 15:29:39 +08:00
}
#[pymethods]
impl Nac3 {
#[new]
fn new(isa: &str, py: Python) -> PyResult<Self> {
2021-09-29 15:33:12 +08:00
let isa = match isa {
"riscv" => Isa::RiscV,
"cortexa9" => Isa::CortexA9,
2021-09-30 22:30:54 +08:00
_ => return Err(exceptions::PyValueError::new_err("invalid ISA")),
2021-09-29 15:33:12 +08:00
};
2021-09-23 19:30:03 +08:00
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
let builtins = if isa == Isa::RiscV {
builtins::timeline_builtins(&primitive)
} else {
vec![]
};
let (composer, builtins_def, builtins_ty) = TopLevelComposer::new(builtins);
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let primitive_ids = PrimitivePythonId {
int: id_fn
.call1((builtins_mod.getattr("int").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int32: id_fn
.call1((numpy_mod.getattr("int32").unwrap(),))
.unwrap()
.extract()
.unwrap(),
int64: id_fn
.call1((numpy_mod.getattr("int64").unwrap(),))
.unwrap()
.extract()
.unwrap(),
bool: id_fn
.call1((builtins_mod.getattr("bool").unwrap(),))
.unwrap()
.extract()
.unwrap(),
float: id_fn
.call1((builtins_mod.getattr("float").unwrap(),))
.unwrap()
.extract()
.unwrap(),
list: id_fn
.call1((builtins_mod.getattr("list").unwrap(),))
.unwrap()
.extract()
.unwrap(),
tuple: id_fn
.call1((builtins_mod.getattr("tuple").unwrap(),))
.unwrap()
.extract()
.unwrap(),
};
2021-09-29 15:33:12 +08:00
Ok(Nac3 {
isa,
2021-09-26 22:17:09 +08:00
primitive,
2021-09-30 22:30:54 +08:00
builtins_ty,
builtins_def,
2021-09-26 22:17:09 +08:00
composer,
primitive_ids,
top_level: None,
2021-09-30 22:30:54 +08:00
pyid_to_def: Default::default(),
pyid_to_type: Default::default(),
to_be_registered: Default::default(),
global_value_ids: Default::default(),
2021-09-29 15:33:12 +08:00
})
2020-12-19 15:29:39 +08:00
}
2021-09-30 22:30:54 +08:00
fn register_module(&mut self, obj: PyObject) {
self.to_be_registered.push(obj);
2020-12-19 15:29:39 +08:00
}
2021-09-23 19:30:03 +08:00
fn analyze(&mut self) -> PyResult<()> {
2021-09-30 22:30:54 +08:00
for obj in std::mem::take(&mut self.to_be_registered).into_iter() {
self.register_module_impl(obj)?;
}
2021-09-23 19:30:03 +08:00
self.composer.start_analysis(true).unwrap();
self.top_level = Some(Arc::new(self.composer.make_top_level_context()));
Ok(())
}
fn compile_method(&mut self, class: u64, method_name: String, py: Python) -> PyResult<()> {
2021-09-23 19:30:03 +08:00
let top_level = self.top_level.as_ref().unwrap();
2021-09-30 22:30:54 +08:00
let module_resolver;
2021-09-23 19:30:03 +08:00
let instance = {
let defs = top_level.definitions.read();
2021-09-30 22:30:54 +08:00
let class_def = defs[self.pyid_to_def.read().get(&class).unwrap().0].write();
let mut method_def = if let TopLevelDef::Class {
methods, resolver, ..
} = &*class_def
{
module_resolver = Some(resolver.clone().unwrap());
if let Some((_name, _unification_key, definition_id)) = methods
.iter()
.find(|method| method.0.to_string() == method_name)
{
2021-09-23 19:30:03 +08:00
defs[definition_id.0].write()
} else {
return Err(exceptions::PyValueError::new_err("method not found"));
2020-12-19 15:29:39 +08:00
}
} else {
2021-09-30 22:30:54 +08:00
return Err(exceptions::PyTypeError::new_err(
"parent object is not a class",
));
2021-09-23 19:30:03 +08:00
};
// FIXME: what is this for? What happens if the kernel is called twice?
if let TopLevelDef::Function {
instance_to_stmt,
instance_to_symbol,
..
} = &mut *method_def
{
2021-09-30 22:30:54 +08:00
instance_to_symbol.insert("".to_string(), method_name);
2021-09-23 19:30:03 +08:00
instance_to_stmt[""].clone()
} else {
unreachable!()
2020-12-19 15:29:39 +08:00
}
2021-09-23 19:30:03 +08:00
};
let signature = FunSignature {
args: vec![],
ret: self.primitive.none,
vars: HashMap::new(),
};
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "__modinit__".to_string(),
2021-09-23 19:30:03 +08:00
body: instance.body,
signature,
2021-09-30 22:30:54 +08:00
resolver: module_resolver.unwrap(),
2021-09-23 19:30:03 +08:00
unifier: top_level.unifiers.read()[instance.unifier_id].clone(),
calls: instance.calls,
};
2021-09-29 15:33:12 +08:00
let isa = self.isa;
2021-09-23 19:30:03 +08:00
let f = Arc::new(WithCall::new(Box::new(move |module| {
let builder = PassManagerBuilder::create();
builder.set_optimization_level(OptimizationLevel::Default);
2021-09-23 19:30:03 +08:00
let passes = PassManager::create(());
builder.populate_module_pass_manager(&passes);
passes.run_on(module);
2020-12-19 15:29:39 +08:00
2021-09-29 15:33:12 +08:00
let (triple, features) = match isa {
Isa::RiscV => (TargetTriple::create("riscv32-unknown-linux"), "+a,+m"),
2021-09-30 22:30:54 +08:00
Isa::CortexA9 => (
TargetTriple::create("armv7-unknown-linux-gnueabihf"),
"+dsp,+fp16,+neon,+vfp3",
),
2021-09-29 15:33:12 +08:00
};
2021-09-23 19:30:03 +08:00
let target =
Target::from_triple(&triple).expect("couldn't create target from target triple");
let target_machine = target
.create_target_machine(
&triple,
"",
2021-09-24 14:45:01 +08:00
features,
2021-09-23 19:30:03 +08:00
OptimizationLevel::Default,
RelocMode::PIC,
2021-09-23 19:30:03 +08:00
CodeModel::Default,
)
.expect("couldn't create target machine");
target_machine
2021-09-30 22:30:54 +08:00
.write_to_file(
module,
FileType::Object,
Path::new(&format!("{}.o", module.get_name().to_str().unwrap())),
)
2021-09-23 19:30:03 +08:00
.expect("couldn't write module to file");
})));
2021-09-23 21:30:13 +08:00
let thread_names: Vec<String> = (0..4).map(|i| format!("module{}", i)).collect();
let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect();
py.allow_threads(|| {
let (registry, handles) =
WorkerRegistry::create_workers(&threads, top_level.clone(), f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
});
2021-09-23 21:30:13 +08:00
let mut linker_args = vec![
"-shared".to_string(),
"--eh-frame-hdr".to_string(),
"-Tkernel.ld".to_string(),
"-x".to_string(),
"-o".to_string(),
2021-09-30 22:30:54 +08:00
"module.elf".to_string(),
];
linker_args.extend(thread_names.iter().map(|name| name.to_owned() + ".o"));
if let Ok(linker_status) = Command::new("ld.lld").args(linker_args).status() {
2021-09-23 21:30:13 +08:00
if !linker_status.success() {
2021-09-30 22:30:54 +08:00
return Err(exceptions::PyRuntimeError::new_err(
"failed to start linker",
));
2021-09-23 21:30:13 +08:00
}
} else {
2021-09-30 22:30:54 +08:00
return Err(exceptions::PyRuntimeError::new_err(
"linker returned non-zero status code",
));
2021-09-23 21:30:13 +08:00
}
2021-09-23 19:30:03 +08:00
Ok(())
2020-12-19 15:29:39 +08:00
}
2020-12-18 10:09:35 +08:00
}
#[pymodule]
fn nac3artiq(_py: Python, m: &PyModule) -> PyResult<()> {
2020-12-19 16:23:12 +08:00
Target::initialize_all(&InitializationConfig::default());
2020-12-19 15:29:39 +08:00
m.add_class::<Nac3>()?;
2020-12-18 10:09:35 +08:00
Ok(())
}