1
0
forked from M-Labs/nac3
nac3/nac3standalone/src/main.rs

483 lines
17 KiB
Rust
Raw Normal View History

2024-06-12 15:13:09 +08:00
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
2024-06-12 15:13:09 +08:00
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use clap::Parser;
use inkwell::context::Context;
use inkwell::{
2024-06-12 14:45:03 +08:00
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
2022-02-21 18:27:46 +08:00
OptimizationLevel,
};
use nac3core::codegen::irrt::setup_irrt_exceptions;
2021-08-19 15:30:52 +08:00
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
2021-08-25 15:30:36 +08:00
symbol_resolver::SymbolResolver,
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
2024-06-12 14:45:03 +08:00
helper::parse_parameter_default_value,
type_annotation::*,
2022-02-21 18:27:46 +08:00
TopLevelDef,
},
2022-02-21 18:27:46 +08:00
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, Unifier, VarMap},
2022-02-21 18:27:46 +08:00
},
};
use nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
2022-02-21 18:27:46 +08:00
parser,
2021-08-19 15:30:52 +08:00
};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
2020-12-17 22:20:30 +08:00
2021-08-19 15:30:52 +08:00
mod basic_symbol_resolver;
use basic_symbol_resolver::*;
2020-12-17 22:20:30 +08:00
// NOTE(review): with clap's derive API the `///` comments on the fields below are presumably
// what is rendered as each option's `--help` text, so rewording them is a user-visible change —
// confirm against the clap version in use before editing them.
// The struct is built once by `CommandLineArgs::parse()` in `main` and immediately destructured.
/// Command-line argument parser definition.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct CommandLineArgs {
/// The name of the input file.
file_name: String,
// `main` converts this to `usize`, and forces it to 1 (with a warning) when
// `inkwell::support::is_multithreaded()` reports that LLVM has no multithreading support.
/// The number of threads allocated to processing the source file. If 0 is passed to this
/// parameter, all available threads will be used for compilation.
#[arg(short = 'T', default_value_t = 1)]
threads: u32,
// The value parser rejects anything outside `0..=3`, so `main` only ever sees 0, 1, 2 or 3.
/// The level to optimize the LLVM IR.
#[arg(short = 'O', default_value_t = 2, value_parser = clap::value_parser!(u32).range(0..=3))]
opt_level: u32,
/// Whether to emit LLVM IR at the end of every module.
///
/// If multithreaded compilation is also enabled, each thread will emit its own module.
#[arg(long, default_value_t = false)]
emit_llvm: bool,
// When absent, `main` falls back to the triple of `CodeGenTargetMachineOptions::from_host()`.
/// The target triple to compile for.
#[arg(long)]
triple: Option<String>,
// `main` replaces the literal value "native" with the host CPU name; absent means empty string.
/// The target CPU to compile for.
#[arg(long)]
mcpu: Option<String>,
// When absent, `main` passes an empty feature string.
/// Additional target features to enable/disable, specified using the `+`/`-` prefixes.
#[arg(long)]
target_features: Option<String>,
}
2022-04-18 16:02:48 +08:00
fn handle_typevar_definition(
var: &Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
) -> Result<Type, HashSet<String>> {
let ExprKind::Call { func, args, .. } = &var.node else {
2024-06-12 14:45:03 +08:00
return Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)]));
};
match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
2024-06-12 14:45:03 +08:00
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `TypeVar`, got {:?}",
&args[0].node
)]));
};
let generic_name: StrRef = ty_name.to_string().into();
2022-04-18 16:02:48 +08:00
let constraints = args
.iter()
.skip(1)
.map(|x| -> Result<Type, HashSet<String>> {
2022-04-18 16:02:48 +08:00
let ty = parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
x,
2024-06-12 15:13:09 +08:00
HashMap::new(),
2022-04-18 16:02:48 +08:00
)?;
get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)
2022-04-18 16:02:48 +08:00
})
.collect::<Result<Vec<_>, _>>()?;
let loc = func.location;
if constraints.len() == 1 {
2024-06-12 14:45:03 +08:00
return Err(HashSet::from([format!(
"A single constraint is not allowed (at {loc})"
)]));
}
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).ty)
}
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 {
2024-06-12 14:45:03 +08:00
return Err(HashSet::from([format!(
"Expected 2 arguments for `ConstGeneric`, got {}",
args.len()
)]));
}
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
2024-06-12 14:45:03 +08:00
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
)]));
};
let generic_name: StrRef = ty_name.to_string().into();
let ty = parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
&args[1],
2024-06-12 15:13:09 +08:00
HashMap::new(),
)?;
2024-06-12 14:45:03 +08:00
let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
2022-04-18 16:02:48 +08:00
}
2024-06-12 14:45:03 +08:00
_ => Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)])),
2022-04-18 16:02:48 +08:00
}
}
fn handle_assignment_pattern(
targets: &[Expr],
value: &Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
internal_resolver: &ResolverInternal,
def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
) -> Result<(), String> {
if targets.len() == 1 {
match &targets[0].node {
ExprKind::Name { id, .. } => {
2024-06-12 14:45:03 +08:00
if let Ok(var) =
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
{
2022-04-18 16:02:48 +08:00
internal_resolver.add_id_type(*id, var);
Ok(())
2024-06-12 14:45:03 +08:00
} else if let Ok(val) = parse_parameter_default_value(value, resolver) {
2022-04-18 16:02:48 +08:00
internal_resolver.add_module_global(*id, val);
Ok(())
} else {
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
2022-04-18 16:02:48 +08:00
targets[0].node,
targets[0].location,
))
}
}
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern(
elts,
value,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
Ok(())
}
_ => Err(format!(
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
}
} else {
match &value.node {
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
if elts.len() == targets.len() {
2022-04-18 16:02:48 +08:00
for (tar, val) in targets.iter().zip(elts) {
handle_assignment_pattern(
std::slice::from_ref(tar),
val,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
}
Ok(())
} else {
Err(format!(
"number of elements to unpack does not match (expect {}, found {}) at {}",
targets.len(),
elts.len(),
value.location
))
2022-04-18 16:02:48 +08:00
}
}
2024-06-12 14:45:03 +08:00
_ => Err(format!("unpack of this expression is not supported at {}", value.location)),
2022-04-18 16:02:48 +08:00
}
}
}
2020-12-17 22:20:30 +08:00
fn main() {
let cli = CommandLineArgs::parse();
2024-06-12 14:45:03 +08:00
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
cli;
Target::initialize_all(&InitializationConfig::default());
let host_target_machine = CodeGenTargetMachineOptions::from_host();
let triple = triple.unwrap_or(host_target_machine.triple.clone());
let mcpu = mcpu
.map(|arg| if arg == "native" { host_target_machine.cpu.clone() } else { arg })
.unwrap_or_default();
let target_features = target_features.unwrap_or_default();
let threads = if is_multithreaded() {
if threads == 0 {
2024-06-12 15:13:09 +08:00
std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1usize)
} else {
2024-06-12 15:13:09 +08:00
threads as usize
}
} else {
if threads != 1 {
println!("Warning: Number of threads specified in command-line but multithreading is disabled in LLVM at build time! Defaulting to single-threaded compilation");
}
1
};
let opt_level = match opt_level {
0 => OptimizationLevel::None,
1 => OptimizationLevel::Less,
2 => OptimizationLevel::Default,
// The default behavior for -O<n> where n>3 defaults to O3 for both Clang and GCC
_ => OptimizationLevel::Aggressive,
};
2021-09-22 14:30:52 +08:00
let target_machine_options = CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
};
let size_t = Context::create()
.ptr_sized_int_type(
&target_machine_options
.create_target_machine(opt_level)
.map(|tm| tm.get_target_data())
.unwrap(),
None,
)
.get_bit_width();
2021-12-28 10:59:17 +08:00
let program = match fs::read_to_string(file_name.clone()) {
2020-12-17 22:20:30 +08:00
Ok(program) => program,
2021-08-19 15:30:52 +08:00
Err(err) => {
println!("Cannot open input file: {err}");
2021-08-19 15:30:52 +08:00
return;
}
};
2021-08-19 15:30:52 +08:00
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(size_t).0;
2022-02-21 18:27:46 +08:00
let (mut composer, builtins_def, builtins_ty) =
TopLevelComposer::new(vec![], vec![], ComposerConfig::default(), size_t);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(),
id_to_def: builtins_def.into(),
class_names: Mutex::default(),
module_globals: Mutex::default(),
str_store: Mutex::default(),
2024-06-12 14:45:03 +08:00
}
.into();
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let context = inkwell::context::Context::create();
// Process IRRT
let irrt = load_irrt(&context);
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
if emit_llvm {
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
}
// Process the Python script
2021-12-28 10:59:17 +08:00
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
for stmt in parser_result {
2022-04-18 16:02:48 +08:00
match &stmt.node {
StmtKind::Assign { targets, value, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Err(err) = handle_assignment_pattern(
targets,
value,
resolver.as_ref(),
internal_resolver.as_ref(),
&def_list,
unifier,
primitives,
) {
eprintln!("{err}");
2022-04-18 16:02:48 +08:00
return;
}
2024-06-12 14:45:03 +08:00
}
2022-04-18 16:02:48 +08:00
// allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. }
2024-06-12 14:45:03 +08:00
if module == &Some("__future__".into())
&& names.len() == 1
&& names[0].name == "annotations".into() => {}
2022-04-18 16:02:48 +08:00
_ => {
2024-06-12 14:45:03 +08:00
let (name, def_id, ty) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__", true)
.unwrap();
2022-04-18 16:02:48 +08:00
internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);
}
}
}
2021-08-27 11:39:36 +08:00
}
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
let signature = store.add_cty(signature);
if let Err(errors) = composer.start_analysis(true) {
let error_count = errors.len();
eprintln!("{error_count} error(s) occurred during top level analysis.");
for (error_i, error) in errors.iter().enumerate() {
let error_num = error_i + 1;
eprintln!("=========== ERROR {error_num}/{error_count} ============");
eprintln!("{error}");
}
eprintln!("==================================");
panic!("top level analysis failed");
}
let top_level = Arc::new(composer.make_top_level_context());
2021-08-25 15:30:36 +08:00
let instance = {
let defs = top_level.definitions.read();
2022-02-21 18:27:46 +08:00
let mut instance = defs[resolver
.get_identifier_def("run".into())
.unwrap_or_else(|_| panic!("cannot find run() entry point"))
.0]
.write();
2024-06-12 14:45:03 +08:00
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), "run".to_string());
instance_to_stmt[""].clone()
};
let llvm_options = CodeGenLLVMOptions { opt_level, target: target_machine_options };
2021-08-19 15:30:52 +08:00
let task = CodeGenTask {
subst: Vec::default(),
2021-08-19 15:30:52 +08:00
symbol_name: "run".to_string(),
body: instance.body,
2021-08-19 15:30:52 +08:00
signature,
resolver,
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
2021-11-20 19:50:25 +08:00
id: 0,
2021-08-19 15:30:52 +08:00
};
let membuffers: Arc<Mutex<Vec<Vec<u8>>>> = Arc::default();
let membuffer = membuffers.clone();
2022-02-13 17:21:42 +08:00
let f = Arc::new(WithCall::new(Box::new(move |module| {
let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
2021-08-19 15:30:52 +08:00
})));
let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
.collect();
2023-12-08 17:43:32 +08:00
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
2021-08-19 15:30:52 +08:00
registry.add_task(task);
registry.wait_tasks_complete(handles);
2021-09-22 14:30:52 +08:00
// Link all modules together into `main`
let buffers = membuffers.lock();
let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
.unwrap();
if emit_llvm {
main.write_bitcode_to_path(Path::new("main.bc"));
}
for (idx, buffer) in buffers.iter().skip(1).enumerate() {
let other = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap();
if emit_llvm {
other.write_bitcode_to_path(Path::new(&format!("module{idx}.bc")));
}
main.link_in_module(other).unwrap();
}
2023-09-29 16:41:32 +08:00
main.link_in_module(irrt).unwrap();
// Private all functions except "run"
let mut function_iter = main.get_first_function();
while let Some(func) = function_iter {
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
func.set_linkage(inkwell::module::Linkage::Private);
}
function_iter = func.get_next_function();
}
// Optimize `main`
2024-06-12 14:45:03 +08:00
let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine");
2023-09-11 13:14:35 +08:00
let pass_options = PassBuilderOptions::create();
pass_options.set_merge_functions(true);
let passes = format!("default<O{}>", opt_level as u32);
let result = main.run_passes(passes.as_str(), &target_machine, pass_options);
2023-09-11 13:14:35 +08:00
if let Err(err) = result {
panic!("Failed to run optimization for module `main`: {}", err.to_string());
}
// Write output
target_machine
2022-02-21 18:27:46 +08:00
.write_to_file(&main, FileType::Object, Path::new("module.o"))
.expect("couldn't write module to file");
2020-12-17 22:20:30 +08:00
}