cells/src/main.rs

759 lines
25 KiB
Rust

use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
use std::path::Path;
use std::process::Command;
use std::sync::Arc;
use eframe::egui;
use parking_lot::{Mutex, RwLock};
use nac3core::codegen;
use nac3core::inkwell;
use nac3core::nac3parser;
use nac3core::toplevel;
use nac3core::toplevel::composer;
use nac3core::typecheck::{type_inferencer, typedef};
mod basic_symbol_resolver;
use basic_symbol_resolver::{Resolver, ResolverInternal};
/// Tries to interpret `var` as a generic-parameter definition in global scope.
///
/// Supported forms:
/// * `TypeVar("name", *constraints)`: a fresh type variable, optionally restricted to the
///   given constraint types (exactly one constraint is rejected);
/// * `ConstGeneric("name", T)`: a fresh const-generic variable whose value has type `T`.
///
/// # Errors
/// Returns the set of error messages if:
/// * `var` is not one of these calls;
/// * the name argument is missing or not a string constant;
/// * the argument count is wrong;
/// * a constraint or type annotation cannot be resolved.
fn handle_typevar_definition(
    var: &nac3parser::ast::Expr,
    resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
    def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
    unifier: &mut nac3core::typecheck::typedef::Unifier,
    primitives: &type_inferencer::PrimitiveStore,
) -> Result<typedef::Type, HashSet<String>> {
    let nac3parser::ast::ExprKind::Call { func, args, .. } = &var.node else {
        return Err(HashSet::from([format!(
            "expression {var:?} cannot be handled as a generic parameter in global scope"
        )]));
    };
    match &func.node {
        nac3parser::ast::ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
            // `TypeVar()` with no arguments is user error, not an invariant violation:
            // report it instead of panicking on `args[0]`.
            let Some(name_arg) = args.first() else {
                return Err(HashSet::from([format!(
                    "Expected at least 1 argument for `TypeVar`, got 0 (at {})",
                    func.location
                )]));
            };
            let nac3parser::ast::ExprKind::Constant {
                value: nac3parser::ast::Constant::Str(ty_name),
                ..
            } = &name_arg.node
            else {
                return Err(HashSet::from([format!(
                    "Expected string constant for first parameter of `TypeVar`, got {:?}",
                    &name_arg.node
                )]));
            };
            let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into();
            // Every argument after the name is a constraint type annotation.
            let constraints = args
                .iter()
                .skip(1)
                .map(|x| -> Result<typedef::Type, HashSet<String>> {
                    let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds(
                        resolver,
                        def_list,
                        unifier,
                        primitives,
                        x,
                        HashMap::new(),
                    )?;
                    toplevel::type_annotation::get_type_from_type_annotation_kinds(
                        def_list, unifier, primitives, &ty, &mut None,
                    )
                })
                .collect::<Result<Vec<_>, _>>()?;
            let loc = func.location;
            if constraints.len() == 1 {
                return Err(HashSet::from([format!(
                    "A single constraint is not allowed (at {loc})"
                )]));
            }
            Ok(unifier
                .get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc))
                .ty)
        }
        nac3parser::ast::ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
            if args.len() != 2 {
                return Err(HashSet::from([format!(
                    "Expected 2 arguments for `ConstGeneric`, got {}",
                    args.len()
                )]));
            }
            let nac3parser::ast::ExprKind::Constant {
                value: nac3parser::ast::Constant::Str(ty_name),
                ..
            } = &args[0].node
            else {
                return Err(HashSet::from([format!(
                    "Expected string constant for first parameter of `ConstGeneric`, got {:?}",
                    &args[0].node
                )]));
            };
            let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into();
            // The second argument is the type of the const-generic value.
            let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds(
                resolver,
                def_list,
                unifier,
                primitives,
                &args[1],
                HashMap::new(),
            )?;
            let constraint = toplevel::type_annotation::get_type_from_type_annotation_kinds(
                def_list, unifier, primitives, &ty, &mut None,
            )?;
            let loc = func.location;
            Ok(unifier
                .get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc))
                .ty)
        }
        _ => Err(HashSet::from([format!(
            "expression {var:?} cannot be handled as a generic parameter in global scope"
        )])),
    }
}
/// Registers the module-level names bound by a global `Assign` statement.
///
/// `targets` holds every target of the statement. Chained assignment `a = b = value`
/// therefore arrives as two targets, and each receives the *whole* `value`.
/// Destructuring such as `a, b = 1, 2` arrives as a single `Tuple`/`List` target. It is
/// unpacked element-wise against a tuple/list literal on the right-hand side.
///
/// Each bound name is recorded in `internal_resolver` in one of two ways:
/// * as a generic type variable (via `handle_typevar_definition`), or, failing that,
/// * as a constant module global.
///
/// # Errors
/// Returns a human-readable message if:
/// * a target is not a name, tuple or list;
/// * a value cannot be evaluated as a constant or generic parameter;
/// * an unpacking has a mismatched element count, or a right-hand side that is not a
///   tuple/list literal.
fn handle_assignment_pattern(
    targets: &[nac3parser::ast::Expr],
    value: &nac3parser::ast::Expr,
    resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
    internal_resolver: &ResolverInternal,
    def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
    unifier: &mut nac3core::typecheck::typedef::Unifier,
    primitives: &type_inferencer::PrimitiveStore,
) -> Result<(), String> {
    for target in targets {
        bind_assignment_target(
            target,
            value,
            resolver,
            internal_resolver,
            def_list,
            unifier,
            primitives,
        )?;
    }
    Ok(())
}
/// Binds one assignment target to `value`, recursing into tuple/list targets.
fn bind_assignment_target(
    target: &nac3parser::ast::Expr,
    value: &nac3parser::ast::Expr,
    resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
    internal_resolver: &ResolverInternal,
    def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
    unifier: &mut nac3core::typecheck::typedef::Unifier,
    primitives: &type_inferencer::PrimitiveStore,
) -> Result<(), String> {
    match &target.node {
        nac3parser::ast::ExprKind::Name { id, .. } => {
            // Prefer a generic-parameter definition; otherwise fall back to a constant.
            if let Ok(var) =
                handle_typevar_definition(value, resolver, def_list, unifier, primitives)
            {
                internal_resolver.add_id_type(*id, var);
                Ok(())
            } else if let Ok(val) =
                toplevel::helper::parse_parameter_default_value(value, resolver)
            {
                internal_resolver.add_module_global(*id, val);
                Ok(())
            } else {
                Err(format!(
                    "fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
                    value.node, value.location,
                ))
            }
        }
        nac3parser::ast::ExprKind::List { elts, .. }
        | nac3parser::ast::ExprKind::Tuple { elts, .. } => unpack_assignment(
            elts,
            value,
            resolver,
            internal_resolver,
            def_list,
            unifier,
            primitives,
        ),
        _ => Err(format!(
            "assignment to {:?} is not supported at {}",
            target, target.location
        )),
    }
}
/// Unpacks the tuple/list literal `value` element-wise into the destructuring `targets`.
fn unpack_assignment(
    targets: &[nac3parser::ast::Expr],
    value: &nac3parser::ast::Expr,
    resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
    internal_resolver: &ResolverInternal,
    def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
    unifier: &mut nac3core::typecheck::typedef::Unifier,
    primitives: &type_inferencer::PrimitiveStore,
) -> Result<(), String> {
    let values = match &value.node {
        nac3parser::ast::ExprKind::List { elts, .. }
        | nac3parser::ast::ExprKind::Tuple { elts, .. } => elts,
        _ => {
            return Err(format!(
                "unpack of this expression is not supported at {}",
                value.location
            ))
        }
    };
    if values.len() != targets.len() {
        return Err(format!(
            "number of elements to unpack does not match (expect {}, found {}) at {}",
            targets.len(),
            values.len(),
            value.location
        ));
    }
    for (target, element) in targets.iter().zip(values) {
        bind_assignment_target(
            target,
            element,
            resolver,
            internal_resolver,
            def_list,
            unifier,
            primitives,
        )?;
    }
    Ok(())
}
/// Records an annotated global declaration (`name: T = value`) as a constant module global.
///
/// Only the initial value is evaluated here; the annotation itself is not inspected.
///
/// # Errors
/// Returns a message if:
/// * the target is not a plain identifier;
/// * no initial value is given;
/// * the value cannot be evaluated as a constant.
fn handle_global_var(
    target: &nac3parser::ast::Expr,
    value: Option<&nac3parser::ast::Expr>,
    resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
    internal_resolver: &ResolverInternal,
) -> Result<(), String> {
    let nac3parser::ast::ExprKind::Name { id, .. } = target.node else {
        return Err(format!(
            "global variable declaration must be an identifier (at {})",
            target.location,
        ));
    };
    let Some(value) = value else {
        return Err(format!(
            "global variable `{id}` must be initialized in its definition"
        ));
    };
    // Report the expression that actually failed to evaluate, not the declared name.
    let Ok(val) = toplevel::helper::parse_parameter_default_value(value, resolver) else {
        return Err(format!(
            "failed to evaluate this expression `{:?}` as a constant at {}",
            value.node, value.location,
        ));
    };
    internal_resolver.add_module_global(id, val);
    Ok(())
}
/// Registers a host-implemented builtin `name` with the given `signature`.
///
/// The function is appended to the composer's definition list with no AST body. Its
/// only (non-generic) instance is mapped to the external symbol `__nac3_cells_{name}`,
/// which must be provided by the host executable (see the `#[no_mangle]`
/// `__nac3_cells_*` functions). The name, definition id and function type are also
/// recorded in `internal_resolver` so cell code can call it.
fn register_cells_function(
    composer: &mut composer::TopLevelComposer,
    resolver: Arc<dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync>,
    internal_resolver: &ResolverInternal,
    name: &str,
    signature: typedef::FunSignature,
) {
    let fun_ty = composer.unifier.add_ty(typedef::TypeEnum::TFunc(signature));
    let simple_name = name.into();
    let mut def = composer::TopLevelComposer::make_top_level_function_def(
        name.to_string(),
        simple_name,
        fun_ty,
        Some(resolver),
        None,
    );

    let toplevel::TopLevelDef::Function {
        instance_to_symbol, ..
    } = &mut def
    else {
        unreachable!();
    };
    instance_to_symbol.insert(String::new(), format!("__nac3_cells_{name}"));

    composer
        .definition_ast_list
        .push((RwLock::new(def).into(), None));
    // The new definition is the last entry, so its id is the final index.
    let def_id = toplevel::DefinitionId(composer.definition_ast_list.len() - 1);
    internal_resolver.add_id_def(simple_name, def_id);
    internal_resolver.add_id_type(simple_name, fun_ty);
}
/// Compiles one cell's source into a position-independent object file at `output_filename`.
///
/// Steps:
/// 1. Register the host builtins `slider(prev: float) -> float` and
///    `plot(data: list[float]) -> None` for the cell.
/// 2. Rename the cell's `run()` to `run_symbol`.
/// 3. Run top-level analysis and code generation.
/// 4. Link all generated modules together with IRRT.
/// 5. Make every defined function except `run_symbol` private.
/// 6. Optimise with `default<O2>` and emit the object file.
///
/// # Errors
/// Returns a human-readable message for:
/// * parse errors and unsupported global statements;
/// * composer or top-level analysis errors;
/// * a missing `run()` entry point;
/// * failures while linking, optimising or writing the object.
///
/// # Panics
/// Panics if the host target machine cannot be created.
fn compile(code: &str, run_symbol: &str, output_filename: &Path) -> Result<(), String> {
    let mut target_machine_options = codegen::CodeGenTargetMachineOptions::from_host();
    // The object is linked into a shared library, so it must be position-independent.
    target_machine_options.reloc_mode = inkwell::targets::RelocMode::PIC;
    let llvm_options = codegen::CodeGenLLVMOptions {
        opt_level: inkwell::OptimizationLevel::Default,
        target: target_machine_options,
    };
    let context = inkwell::context::Context::create();
    let target_machine = llvm_options
        .target
        .create_target_machine(llvm_options.opt_level)
        .expect("couldn't create target machine");
    let size_t = context
        .ptr_sized_int_type(&target_machine.get_target_data(), None)
        .get_bit_width();
    let primitive: type_inferencer::PrimitiveStore =
        composer::TopLevelComposer::make_primitives(size_t).0;
    let (mut composer, builtins_def, builtins_ty) = composer::TopLevelComposer::new(
        vec![],
        vec![],
        composer::ComposerConfig::default(),
        size_t,
    );
    let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
        id_to_type: builtins_ty.into(),
        id_to_def: builtins_def.into(),
        module_globals: Mutex::default(),
        str_store: Mutex::default(),
    }
    .into();
    let resolver = Arc::new(Resolver(internal_resolver.clone()))
        as Arc<dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync>;
    let irrt = codegen::irrt::load_irrt(&context, resolver.as_ref());

    // Build `list[float]`, the parameter type of the `plot` builtin.
    let list_tvar = if let typedef::TypeEnum::TObj { params, .. } =
        &*composer.unifier.get_ty_immutable(primitive.list)
    {
        typedef::iter_type_vars(params).next().unwrap()
    } else {
        unreachable!()
    };
    let list_float = composer
        .unifier
        .subst(
            primitive.list,
            &typedef::into_var_map([typedef::TypeVar {
                id: list_tvar.id,
                ty: primitive.float,
            }]),
        )
        .unwrap();
    register_cells_function(
        &mut composer,
        resolver.clone(),
        internal_resolver.as_ref(),
        "slider",
        typedef::FunSignature {
            args: vec![typedef::FuncArg {
                name: "prev".into(),
                ty: primitive.float,
                default_value: None,
                is_vararg: false,
            }],
            ret: primitive.float,
            vars: typedef::VarMap::new(),
        },
    );
    register_cells_function(
        &mut composer,
        resolver.clone(),
        internal_resolver.as_ref(),
        "plot",
        typedef::FunSignature {
            args: vec![typedef::FuncArg {
                name: "data".into(),
                ty: list_float,
                default_value: None,
                is_vararg: false,
            }],
            ret: primitive.none,
            vars: typedef::VarMap::new(),
        },
    );

    let parser_result = nac3parser::parser::parse_program(code, String::from("cell1").into())
        .map_err(|err| format!("parse error: {}", err))?;
    for mut stmt in parser_result {
        match stmt.node {
            nac3parser::ast::StmtKind::Assign { targets, value, .. } => {
                let def_list = composer.extract_def_list();
                let unifier = &mut composer.unifier;
                let primitives = &composer.primitives_ty;
                handle_assignment_pattern(
                    &targets,
                    &value,
                    resolver.as_ref(),
                    internal_resolver.as_ref(),
                    &def_list,
                    unifier,
                    primitives,
                )?;
            }
            nac3parser::ast::StmtKind::AnnAssign {
                ref target,
                ref value,
                ..
            } => {
                handle_global_var(
                    target,
                    value.as_ref().map(Box::as_ref),
                    resolver.as_ref(),
                    internal_resolver.as_ref(),
                )?;
                // A duplicate or otherwise invalid global is a user error: report, don't panic.
                let (name, def_id, _) = composer
                    .register_top_level(stmt, Some(resolver.clone()), "__main__", true)
                    .map_err(|err| format!("composer error: {}", err))?;
                internal_resolver.add_id_def(name, def_id);
            }
            _ => {
                // Rename the user's `run()` so the entry point gets the caller-chosen symbol.
                if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node {
                    if name.to_string() == "run" {
                        *name = run_symbol.into();
                    }
                }
                let (name, def_id, ty) = composer
                    .register_top_level(stmt, Some(resolver.clone()), "__main__", true)
                    .map_err(|err| format!("composer error: {}", err))?;
                internal_resolver.add_id_def(name, def_id);
                if let Some(ty) = ty {
                    internal_resolver.add_id_type(name, ty);
                }
            }
        }
    }

    // The entry point is called from Rust as `fn()`, i.e. no arguments and returning `None`.
    let signature = typedef::FunSignature {
        args: vec![],
        ret: primitive.none,
        vars: typedef::VarMap::new(),
    };
    let mut store = codegen::concrete_type::ConcreteTypeStore::new();
    let mut cache = HashMap::new();
    let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
    let signature = store.add_cty(signature);
    if let Err(errors) = composer.start_analysis(true) {
        let error_count = errors.len();
        let mut msg = format!("{error_count} error(s) occurred during top level analysis.");
        for (error_i, error) in errors.iter().enumerate() {
            let error_num = error_i + 1;
            msg.push_str(&format!(
                "\n=========== ERROR {error_num}/{error_count} ============\n{error}"
            ));
        }
        return Err(msg);
    }
    let top_level = Arc::new(composer.make_top_level_context());
    let run_id_def = resolver
        .get_identifier_def(run_symbol.into())
        .map_err(|_| "no run() entry point".to_string())?;
    let instance = {
        let defs = top_level.definitions.read();
        let mut instance = defs[run_id_def.0].write();
        let toplevel::TopLevelDef::Function {
            instance_to_stmt,
            instance_to_symbol,
            ..
        } = &mut *instance
        else {
            unreachable!()
        };
        instance_to_symbol.insert(String::new(), run_symbol.to_string());
        instance_to_stmt[""].clone()
    };
    let task = codegen::CodeGenTask {
        subst: Vec::default(),
        symbol_name: run_symbol.to_string(),
        body: instance.body,
        signature,
        resolver,
        store,
        unifier_index: instance.unifier_id,
        calls: instance.calls,
        id: 0,
    };
    let nthreads = if inkwell::support::is_multithreaded() {
        std::thread::available_parallelism()
            .map(NonZeroUsize::get)
            .unwrap_or(1usize)
    } else {
        1
    };
    // Each worker serialises its finished module to bitcode into this shared list.
    let membuffers: Arc<Mutex<Vec<Vec<u8>>>> = Arc::default();
    let membuffer = membuffers.clone();
    let f = Arc::new(codegen::WithCall::new(Box::new(move |module| {
        let buffer = module.write_bitcode_to_memory();
        let buffer = buffer.as_slice().into();
        membuffer.lock().push(buffer);
    })));
    let threads = (0..nthreads)
        .map(|i| {
            Box::new(codegen::DefaultCodeGenerator::new(
                format!("module{i}"),
                size_t,
            ))
        })
        .collect();
    let (registry, handles) =
        codegen::WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
    registry.add_task(task);
    registry.wait_tasks_complete(handles);

    // Link all modules together into `main`
    let buffers = membuffers.lock();
    let Some((first, rest)) = buffers.split_first() else {
        return Err("code generation produced no modules".to_string());
    };
    let main = context
        .create_module_from_ir(
            inkwell::memory_buffer::MemoryBuffer::create_from_memory_range(first, "main"),
        )
        .map_err(|err| format!("failed to load generated module: {}", err.to_string()))?;
    for buffer in rest {
        let other = context
            .create_module_from_ir(
                inkwell::memory_buffer::MemoryBuffer::create_from_memory_range(buffer, "main"),
            )
            .map_err(|err| format!("failed to load generated module: {}", err.to_string()))?;
        main.link_in_module(other)
            .map_err(|err| format!("failed to link generated modules: {}", err.to_string()))?;
    }
    main.link_in_module(irrt)
        .map_err(|err| format!("failed to link IRRT: {}", err.to_string()))?;

    // Private all functions except "run": only the renamed entry point stays exported.
    let mut function_iter = main.get_first_function();
    while let Some(func) = function_iter {
        if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != run_symbol {
            func.set_linkage(inkwell::module::Linkage::Private);
        }
        function_iter = func.get_next_function();
    }
    let pass_options = inkwell::passes::PassBuilderOptions::create();
    pass_options.set_merge_functions(true);
    main.run_passes("default<O2>", &target_machine, pass_options)
        .map_err(|err| format!("failed to optimize module: {}", err.to_string()))?;
    target_machine
        .write_to_file(&main, inkwell::targets::FileType::Object, output_filename)
        .map_err(|err| format!("failed to write object file: {}", err.to_string()))?;
    Ok(())
}
// The year is 2024, and compiler toolchains are still a trash fire.
/// Links `obj_filename` into the shared library `elf_filename` using `ld.lld` from `PATH`.
///
/// Linker flags:
/// * `-shared` produces a shared object;
/// * `--eh-frame-hdr` emits an `.eh_frame_hdr` section;
/// * `-x` discards local symbols.
///
/// Paths are passed as `OsStr`, so non-UTF-8 paths do not panic.
///
/// # Errors
/// Returns `"failed to start linker: …"` if `ld.lld` could not be spawned, or
/// `"linker returned non-zero status code (…)"` if it ran and failed.
fn link_with_lld(elf_filename: &Path, obj_filename: &Path) -> Result<(), String> {
    let status = Command::new("ld.lld")
        .args(["-shared", "--eh-frame-hdr", "-x", "-o"])
        .arg(elf_filename)
        .arg(obj_filename)
        .status()
        .map_err(|err| format!("failed to start linker: {err}"))?;
    if status.success() {
        Ok(())
    } else {
        Err(format!("linker returned non-zero status code ({status})"))
    }
}
/// Signature of a compiled cell's exported entry point (its renamed `run()`):
/// no arguments, no return value.
type RunFn = unsafe extern "C" fn();
/// A compiled cell: a temporary directory holding the object file and shared library,
/// the loaded library, and the entry point resolved from it.
struct CellBin {
    // Used to build the entry-point symbol name `__cells_run_{cell_id}`.
    cell_id: usize,
    // note destructor order: fields are dropped in declaration order, so the function
    // pointer is discarded before the library is unloaded, and the library is unloaded
    // before the temporary directory (which holds `module.so`) is deleted.
    run_fn: Option<RunFn>,
    library: Option<libloading::Library>,
    directory: tempfile::TempDir,
}
impl CellBin {
    /// Creates an unloaded binary for cell `cell_id`, backed by a fresh temporary directory.
    ///
    /// # Panics
    /// Panics if the temporary directory cannot be created.
    fn new(cell_id: usize) -> Self {
        Self {
            cell_id,
            run_fn: None,
            library: None,
            directory: tempfile::tempdir().unwrap(),
        }
    }
    /// Compiles and loads `code`, then resolves its entry point.
    ///
    /// `code` is compiled to `module.o` and linked into `module.so`, both inside this
    /// binary's temporary directory. The library is then loaded and its entry point
    /// resolved.
    ///
    /// On success both `library` and `run_fn` are set; on failure neither is.
    ///
    /// # Errors
    /// Returns the compiler or linker message, or the loader's message if the library
    /// cannot be loaded or does not export the entry point.
    ///
    /// # Panics
    /// Panics if this binary has already been loaded.
    fn compile_and_load(&mut self, code: &String) -> Result<(), String> {
        assert!(self.run_fn.is_none());
        assert!(self.library.is_none());
        let object = self.directory.path().join("module.o");
        let library_path = self.directory.path().join("module.so");
        let run_symbol = format!("__cells_run_{}", self.cell_id);
        compile(code, &run_symbol, &object)?;
        link_with_lld(&library_path, &object)?;
        // SAFETY: loading runs the library's initialisers. The library was just produced by
        // `compile` + `link_with_lld` from this cell's code.
        let library =
            unsafe { libloading::Library::new(&library_path) }.map_err(|e| e.to_string())?;
        // SAFETY: `compile` exports `run_symbol` with signature `() -> None`, matching `RunFn`.
        // The copied function pointer must not outlive `library`. Both are stored in `self`,
        // and `run_fn` is dropped before `library` (field declaration order).
        let run_fn: RunFn = unsafe { library.get::<RunFn>(run_symbol.as_bytes()) }
            .map(|symbol| *symbol)
            .map_err(|e| e.to_string())?;
        self.library = Some(library);
        self.run_fn = Some(run_fn);
        Ok(())
    }
}
/// One notebook cell: its editable source and the outcome of its last compilation.
struct Cell {
    code: String,
    // `Err("")` means "not compiled yet" (nothing is shown); a non-empty `Err` is an
    // error message to display under the editor.
    result: Result<CellBin, String>,
}
impl Cell {
    /// Source text that every new cell starts with.
    const DEFAULT_CODE: &'static str = "def run():\n pass";
    /// Creates an uncompiled cell containing `DEFAULT_CODE`.
    fn new() -> Self {
        Cell {
            code: String::from(Self::DEFAULT_CODE),
            result: Err(String::new()),
        }
    }
    /// Whether the source is still exactly the default text.
    fn is_default(&self) -> bool {
        self.code.as_str() == Self::DEFAULT_CODE
    }
    /// Recompiles the current source into a fresh binary and replaces the previous result.
    fn update(&mut self) {
        let mut bin = CellBin::new(0usize);
        self.result = match bin.compile_and_load(&self.code) {
            Ok(()) => Ok(bin),
            Err(message) => Err(message),
        };
    }
}
/// The notebook: the ordered list of cells it displays.
struct Cells {
    cells: Vec<Cell>,
}
/// UI of the cell whose `run` function is currently executing.
///
/// Set by `Cells::ui` just before calling a cell's `run` function and cleared right after.
/// Read by the `__nac3_cells_*` host callbacks so they can draw into that cell. The
/// `'static` lifetime is manufactured with `transmute`, so the reference must never be
/// used outside that window.
static mut CELL_UI: Option<&'static mut egui::Ui> = None;
/// Host implementation of the cell builtin `slider(prev: float) -> float`.
///
/// Draws a slider over `0.0..=100.0`, initialised to `prev`, in the UI of the currently
/// running cell, and returns the (possibly user-modified) value.
///
/// # Panics
/// Panics if called while no cell is running (`CELL_UI` unset). This function is
/// `extern "C"`, so such a panic will not unwind cleanly into the caller.
#[no_mangle]
pub extern "C" fn __nac3_cells_slider(prev: f64) -> f64 {
    // SAFETY: `CELL_UI` is only set by `Cells::ui` for the duration of the call into the
    // cell's `run` function, which is what calls back into here. `addr_of_mut!` avoids
    // creating a reference to the `static mut` itself (`static_mut_refs`).
    let ui = unsafe { (*std::ptr::addr_of_mut!(CELL_UI)).as_mut() }
        .expect("`slider` called outside of a running cell");
    let mut value = prev;
    ui.add(egui::Slider::new(&mut value, 0.0..=100.0));
    value
}
/// FFI view of a `list[float]` as passed by pointer from generated code to host callbacks.
///
/// `#[repr(C)]` pins the field order and layout; without it Rust may reorder fields.
/// NOTE(review): assumes NAC3 lays out lists as `{ T*, size_t }`. Confirm against the
/// NAC3 codegen.
#[repr(C)]
pub struct List {
    data: *mut f64,
    length: usize,
}
/// Host implementation of the cell builtin `plot(data: list[float])`.
///
/// Draws `data` as a line plot (x = element index, y = element value) in the UI of the
/// currently running cell. A null list pointer, or an empty or null-data list, produces
/// an empty plot (null list: nothing is drawn).
///
/// # Panics
/// Panics if called while no cell is running (`CELL_UI` unset). This function is
/// `extern "C"`, so such a panic will not unwind cleanly into the caller.
#[no_mangle]
pub extern "C" fn __nac3_cells_plot(data: *const List) {
    // SAFETY: `CELL_UI` is only set by `Cells::ui` for the duration of the call into the
    // cell's `run` function, which is what calls back into here. `addr_of_mut!` avoids
    // creating a reference to the `static mut` itself (`static_mut_refs`).
    let ui = unsafe { (*std::ptr::addr_of_mut!(CELL_UI)).as_mut() }
        .expect("`plot` called outside of a running cell");
    // SAFETY: generated code passes a pointer to a live `List` (or null, handled here).
    let Some(list) = (unsafe { data.as_ref() }) else {
        return;
    };
    // `slice::from_raw_parts` requires a non-null pointer even for length 0.
    let values: &[f64] = if list.length == 0 || list.data.is_null() {
        &[]
    } else {
        // SAFETY: assumes `list.data` is valid for reads of `list.length` `f64`s, as
        // produced by NAC3 list codegen. TODO confirm.
        unsafe { std::slice::from_raw_parts(list.data, list.length) }
    };
    let points: Vec<[f64; 2]> = values
        .iter()
        .enumerate()
        .map(|(i, &y)| [i as f64, y])
        .collect();
    let line = egui_plot::Line::new(points);
    egui_plot::Plot::new("my_plot")
        .view_aspect(2.0)
        .show(ui, |plot_ui| plot_ui.line(line));
}
impl Cells {
    /// Creates a notebook containing a single default cell.
    fn new() -> Self {
        Self {
            cells: vec![Cell::new()],
        }
    }
    /// Recompiles and reloads every cell from its current source.
    fn update(&mut self) {
        for cell in self.cells.iter_mut() {
            cell.update()
        }
    }
    /// Appends a fresh default cell once the last cell compiled successfully and was edited.
    fn ensure_last(&mut self) {
        let last_cell = self
            .cells
            .last()
            .expect("`Cells` always contains at least one cell");
        if last_cell.result.is_ok() && !last_cell.is_default() {
            self.cells.push(Cell::new());
        }
    }
    /// Draws every cell: its index label, code editor, and then either its output or its error.
    ///
    /// For a successfully compiled cell, its `run` function is called every frame. During
    /// that call `CELL_UI` points at the cell's group, so the host callbacks can draw into it.
    fn ui(&mut self, ui: &mut egui::Ui) {
        // The theme and layouter do not depend on the cell, so build them once per frame.
        let theme = egui_extras::syntax_highlighting::CodeTheme::from_memory(ui.ctx());
        let mut layouter = |ui: &egui::Ui, string: &str, wrap_width: f32| {
            let mut layout_job =
                egui_extras::syntax_highlighting::highlight(ui.ctx(), &theme, string, "Python");
            layout_job.wrap.max_width = wrap_width;
            for section in &mut layout_job.sections {
                section.format.font_id = egui::FontId::monospace(16.0);
            }
            ui.fonts(|f| f.layout_job(layout_job))
        };
        for (cell_i, cell) in self.cells.iter_mut().enumerate() {
            ui.group(|ui| {
                ui.horizontal(|ui| {
                    ui.label(format!("[{}]", cell_i));
                    ui.add(
                        egui::TextEdit::multiline(&mut cell.code)
                            .code_editor()
                            .desired_rows(4)
                            .lock_focus(true)
                            .desired_width(f32::INFINITY)
                            .layouter(&mut layouter)
                            .font(egui::FontId::monospace(16.0)),
                    );
                });
                match &cell.result {
                    Ok(bin) => {
                        // Resolve before publishing `CELL_UI`, so a broken invariant cannot
                        // leave a dangling UI pointer behind.
                        let run_fn = bin
                            .run_fn
                            .expect("a successfully loaded cell always has a run function");
                        // SAFETY: `CELL_UI` holds this `ui` only for the duration of `run_fn()`,
                        // and `ui` is not otherwise used meanwhile. Writes go through
                        // `addr_of_mut!` so no reference to the `static mut` is created.
                        // `run_fn` was resolved from the cell's freshly loaded library, which
                        // `bin` keeps alive.
                        unsafe {
                            *std::ptr::addr_of_mut!(CELL_UI) = Some(std::mem::transmute::<
                                &mut egui::Ui,
                                &'static mut egui::Ui,
                            >(ui));
                            run_fn();
                            *std::ptr::addr_of_mut!(CELL_UI) = None;
                        }
                    }
                    Err(msg) => {
                        // An empty message means "not compiled yet": show nothing.
                        if !msg.is_empty() {
                            ui.colored_label(egui::Color32::from_rgb(255, 0, 0), msg);
                        }
                    }
                };
            });
        }
    }
}
/// Initialises LLVM's native target and runs the notebook UI.
///
/// Ctrl+Enter recompiles every cell and then lets the notebook append a new cell if
/// needed.
fn main() -> eframe::Result {
    let llvm_init = inkwell::targets::InitializationConfig::default();
    inkwell::targets::Target::initialize_native(&llvm_init).unwrap();

    let viewport = egui::ViewportBuilder::default().with_inner_size([1024.0, 768.0]);
    let options = eframe::NativeOptions {
        viewport,
        ..Default::default()
    };

    let mut notebook = Cells::new();
    eframe::run_simple_native("Cells", options, move |ctx, _frame| {
        let submit_key = egui::KeyboardShortcut::new(egui::Modifiers::CTRL, egui::Key::Enter);
        let submitted = ctx.input_mut(|input| input.consume_shortcut(&submit_key));
        if submitted {
            notebook.update();
            notebook.ensure_last();
        }
        egui::CentralPanel::default().show(ctx, |ui| notebook.ui(ui));
    })
}
/// Stub definition of the `__nac3_personality` symbol.
///
/// NOTE(review): presumably exported so that compiled cell libraries referencing NAC3's
/// exception personality routine can resolve it against this executable. Confirm.
/// Exceptions are not supported: reaching this function panics via `unimplemented!()`,
/// and because it is `extern "C"` that panic will not unwind cleanly into the caller.
#[no_mangle]
pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _context: u32) -> u32 {
    unimplemented!()
}