forked from M-Labs/nac3
483 lines
17 KiB
Rust
483 lines
17 KiB
Rust
#![deny(
|
|
future_incompatible,
|
|
let_underscore,
|
|
nonstandard_style,
|
|
rust_2024_compatibility,
|
|
clippy::all
|
|
)]
|
|
#![warn(clippy::pedantic)]
|
|
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
|
|
|
|
use clap::Parser;
|
|
use inkwell::context::Context;
|
|
use inkwell::{
|
|
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
|
|
OptimizationLevel,
|
|
};
|
|
use nac3core::codegen::irrt::setup_irrt_exceptions;
|
|
use nac3core::{
|
|
codegen::{
|
|
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
|
|
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
|
|
},
|
|
symbol_resolver::SymbolResolver,
|
|
toplevel::{
|
|
composer::{ComposerConfig, TopLevelComposer},
|
|
helper::parse_parameter_default_value,
|
|
type_annotation::*,
|
|
TopLevelDef,
|
|
},
|
|
typecheck::{
|
|
type_inferencer::PrimitiveStore,
|
|
typedef::{FunSignature, Type, Unifier, VarMap},
|
|
},
|
|
};
|
|
use nac3parser::{
|
|
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
|
|
parser,
|
|
};
|
|
use parking_lot::{Mutex, RwLock};
|
|
use std::collections::HashSet;
|
|
use std::num::NonZeroUsize;
|
|
use std::{collections::HashMap, fs, path::Path, sync::Arc};
|
|
|
|
mod basic_symbol_resolver;
|
|
use basic_symbol_resolver::*;
|
|
|
|
/// Command-line argument parser definition.
|
|
#[derive(Parser)]
|
|
#[command(author, version, about, long_about = None)]
|
|
struct CommandLineArgs {
|
|
/// The name of the input file.
|
|
file_name: String,
|
|
|
|
/// The number of threads allocated to processing the source file. If 0 is passed to this
|
|
/// parameter, all available threads will be used for compilation.
|
|
#[arg(short = 'T', default_value_t = 1)]
|
|
threads: u32,
|
|
|
|
/// The level to optimize the LLVM IR.
|
|
#[arg(short = 'O', default_value_t = 2, value_parser = clap::value_parser!(u32).range(0..=3))]
|
|
opt_level: u32,
|
|
|
|
/// Whether to emit LLVM IR at the end of every module.
|
|
///
|
|
/// If multithreaded compilation is also enabled, each thread will emit its own module.
|
|
#[arg(long, default_value_t = false)]
|
|
emit_llvm: bool,
|
|
|
|
/// The target triple to compile for.
|
|
#[arg(long)]
|
|
triple: Option<String>,
|
|
|
|
/// The target CPU to compile for.
|
|
#[arg(long)]
|
|
mcpu: Option<String>,
|
|
|
|
/// Additional target features to enable/disable, specified using the `+`/`-` prefixes.
|
|
#[arg(long)]
|
|
target_features: Option<String>,
|
|
}
|
|
|
|
fn handle_typevar_definition(
|
|
var: &Expr,
|
|
resolver: &(dyn SymbolResolver + Send + Sync),
|
|
def_list: &[Arc<RwLock<TopLevelDef>>],
|
|
unifier: &mut Unifier,
|
|
primitives: &PrimitiveStore,
|
|
) -> Result<Type, HashSet<String>> {
|
|
let ExprKind::Call { func, args, .. } = &var.node else {
|
|
return Err(HashSet::from([format!(
|
|
"expression {var:?} cannot be handled as a generic parameter in global scope"
|
|
)]));
|
|
};
|
|
|
|
match &func.node {
|
|
ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
|
|
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
|
|
return Err(HashSet::from([format!(
|
|
"Expected string constant for first parameter of `TypeVar`, got {:?}",
|
|
&args[0].node
|
|
)]));
|
|
};
|
|
let generic_name: StrRef = ty_name.to_string().into();
|
|
|
|
let constraints = args
|
|
.iter()
|
|
.skip(1)
|
|
.map(|x| -> Result<Type, HashSet<String>> {
|
|
let ty = parse_ast_to_type_annotation_kinds(
|
|
resolver,
|
|
def_list,
|
|
unifier,
|
|
primitives,
|
|
x,
|
|
HashMap::new(),
|
|
)?;
|
|
get_type_from_type_annotation_kinds(
|
|
def_list, unifier, primitives, &ty, &mut None,
|
|
)
|
|
})
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
let loc = func.location;
|
|
|
|
if constraints.len() == 1 {
|
|
return Err(HashSet::from([format!(
|
|
"A single constraint is not allowed (at {loc})"
|
|
)]));
|
|
}
|
|
|
|
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).ty)
|
|
}
|
|
|
|
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
|
|
if args.len() != 2 {
|
|
return Err(HashSet::from([format!(
|
|
"Expected 2 arguments for `ConstGeneric`, got {}",
|
|
args.len()
|
|
)]));
|
|
}
|
|
|
|
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
|
|
return Err(HashSet::from([format!(
|
|
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
|
|
&args[0].node
|
|
)]));
|
|
};
|
|
let generic_name: StrRef = ty_name.to_string().into();
|
|
|
|
let ty = parse_ast_to_type_annotation_kinds(
|
|
resolver,
|
|
def_list,
|
|
unifier,
|
|
primitives,
|
|
&args[1],
|
|
HashMap::new(),
|
|
)?;
|
|
let constraint =
|
|
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
|
|
let loc = func.location;
|
|
|
|
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
|
|
}
|
|
|
|
_ => Err(HashSet::from([format!(
|
|
"expression {var:?} cannot be handled as a generic parameter in global scope"
|
|
)])),
|
|
}
|
|
}
|
|
|
|
fn handle_assignment_pattern(
|
|
targets: &[Expr],
|
|
value: &Expr,
|
|
resolver: &(dyn SymbolResolver + Send + Sync),
|
|
internal_resolver: &ResolverInternal,
|
|
def_list: &[Arc<RwLock<TopLevelDef>>],
|
|
unifier: &mut Unifier,
|
|
primitives: &PrimitiveStore,
|
|
) -> Result<(), String> {
|
|
if targets.len() == 1 {
|
|
match &targets[0].node {
|
|
ExprKind::Name { id, .. } => {
|
|
if let Ok(var) =
|
|
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
|
|
{
|
|
internal_resolver.add_id_type(*id, var);
|
|
Ok(())
|
|
} else if let Ok(val) = parse_parameter_default_value(value, resolver) {
|
|
internal_resolver.add_module_global(*id, val);
|
|
Ok(())
|
|
} else {
|
|
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
|
|
targets[0].node,
|
|
targets[0].location,
|
|
))
|
|
}
|
|
}
|
|
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
|
|
handle_assignment_pattern(
|
|
elts,
|
|
value,
|
|
resolver,
|
|
internal_resolver,
|
|
def_list,
|
|
unifier,
|
|
primitives,
|
|
)?;
|
|
Ok(())
|
|
}
|
|
_ => Err(format!(
|
|
"assignment to {:?} is not supported at {}",
|
|
targets[0], targets[0].location
|
|
)),
|
|
}
|
|
} else {
|
|
match &value.node {
|
|
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
|
|
if elts.len() == targets.len() {
|
|
for (tar, val) in targets.iter().zip(elts) {
|
|
handle_assignment_pattern(
|
|
std::slice::from_ref(tar),
|
|
val,
|
|
resolver,
|
|
internal_resolver,
|
|
def_list,
|
|
unifier,
|
|
primitives,
|
|
)?;
|
|
}
|
|
Ok(())
|
|
} else {
|
|
Err(format!(
|
|
"number of elements to unpack does not match (expect {}, found {}) at {}",
|
|
targets.len(),
|
|
elts.len(),
|
|
value.location
|
|
))
|
|
}
|
|
}
|
|
_ => Err(format!("unpack of this expression is not supported at {}", value.location)),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn main() {
|
|
let cli = CommandLineArgs::parse();
|
|
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
|
|
cli;
|
|
|
|
Target::initialize_all(&InitializationConfig::default());
|
|
|
|
let host_target_machine = CodeGenTargetMachineOptions::from_host();
|
|
let triple = triple.unwrap_or(host_target_machine.triple.clone());
|
|
let mcpu = mcpu
|
|
.map(|arg| if arg == "native" { host_target_machine.cpu.clone() } else { arg })
|
|
.unwrap_or_default();
|
|
let target_features = target_features.unwrap_or_default();
|
|
let threads = if is_multithreaded() {
|
|
if threads == 0 {
|
|
std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1usize)
|
|
} else {
|
|
threads as usize
|
|
}
|
|
} else {
|
|
if threads != 1 {
|
|
println!("Warning: Number of threads specified in command-line but multithreading is disabled in LLVM at build time! Defaulting to single-threaded compilation");
|
|
}
|
|
1
|
|
};
|
|
let opt_level = match opt_level {
|
|
0 => OptimizationLevel::None,
|
|
1 => OptimizationLevel::Less,
|
|
2 => OptimizationLevel::Default,
|
|
// The default behavior for -O<n> where n>3 defaults to O3 for both Clang and GCC
|
|
_ => OptimizationLevel::Aggressive,
|
|
};
|
|
|
|
let target_machine_options = CodeGenTargetMachineOptions {
|
|
triple,
|
|
cpu: mcpu,
|
|
features: target_features,
|
|
reloc_mode: RelocMode::PIC,
|
|
..host_target_machine
|
|
};
|
|
|
|
let size_t = Context::create()
|
|
.ptr_sized_int_type(
|
|
&target_machine_options
|
|
.create_target_machine(opt_level)
|
|
.map(|tm| tm.get_target_data())
|
|
.unwrap(),
|
|
None,
|
|
)
|
|
.get_bit_width();
|
|
|
|
let program = match fs::read_to_string(file_name.clone()) {
|
|
Ok(program) => program,
|
|
Err(err) => {
|
|
println!("Cannot open input file: {err}");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(size_t).0;
|
|
let (mut composer, builtins_def, builtins_ty) =
|
|
TopLevelComposer::new(vec![], vec![], ComposerConfig::default(), size_t);
|
|
|
|
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
|
|
id_to_type: builtins_ty.into(),
|
|
id_to_def: builtins_def.into(),
|
|
class_names: Mutex::default(),
|
|
module_globals: Mutex::default(),
|
|
str_store: Mutex::default(),
|
|
}
|
|
.into();
|
|
let resolver =
|
|
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
|
|
|
let context = inkwell::context::Context::create();
|
|
|
|
// Process IRRT
|
|
let irrt = load_irrt(&context);
|
|
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
|
|
if emit_llvm {
|
|
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
|
}
|
|
|
|
// Process the Python script
|
|
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
|
|
|
|
for stmt in parser_result {
|
|
match &stmt.node {
|
|
StmtKind::Assign { targets, value, .. } => {
|
|
let def_list = composer.extract_def_list();
|
|
let unifier = &mut composer.unifier;
|
|
let primitives = &composer.primitives_ty;
|
|
if let Err(err) = handle_assignment_pattern(
|
|
targets,
|
|
value,
|
|
resolver.as_ref(),
|
|
internal_resolver.as_ref(),
|
|
&def_list,
|
|
unifier,
|
|
primitives,
|
|
) {
|
|
eprintln!("{err}");
|
|
return;
|
|
}
|
|
}
|
|
// allow (and ignore) "from __future__ import annotations"
|
|
StmtKind::ImportFrom { module, names, .. }
|
|
if module == &Some("__future__".into())
|
|
&& names.len() == 1
|
|
&& names[0].name == "annotations".into() => {}
|
|
_ => {
|
|
let (name, def_id, ty) = composer
|
|
.register_top_level(stmt, Some(resolver.clone()), "__main__", true)
|
|
.unwrap();
|
|
internal_resolver.add_id_def(name, def_id);
|
|
if let Some(ty) = ty {
|
|
internal_resolver.add_id_type(name, ty);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new() };
|
|
let mut store = ConcreteTypeStore::new();
|
|
let mut cache = HashMap::new();
|
|
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
|
|
let signature = store.add_cty(signature);
|
|
|
|
if let Err(errors) = composer.start_analysis(true) {
|
|
let error_count = errors.len();
|
|
eprintln!("{error_count} error(s) occurred during top level analysis.");
|
|
|
|
for (error_i, error) in errors.iter().enumerate() {
|
|
let error_num = error_i + 1;
|
|
eprintln!("=========== ERROR {error_num}/{error_count} ============");
|
|
eprintln!("{error}");
|
|
}
|
|
eprintln!("==================================");
|
|
|
|
panic!("top level analysis failed");
|
|
}
|
|
|
|
let top_level = Arc::new(composer.make_top_level_context());
|
|
|
|
let instance = {
|
|
let defs = top_level.definitions.read();
|
|
let mut instance = defs[resolver
|
|
.get_identifier_def("run".into())
|
|
.unwrap_or_else(|_| panic!("cannot find run() entry point"))
|
|
.0]
|
|
.write();
|
|
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance
|
|
else {
|
|
unreachable!()
|
|
};
|
|
instance_to_symbol.insert(String::new(), "run".to_string());
|
|
instance_to_stmt[""].clone()
|
|
};
|
|
|
|
let llvm_options = CodeGenLLVMOptions { opt_level, target: target_machine_options };
|
|
|
|
let task = CodeGenTask {
|
|
subst: Vec::default(),
|
|
symbol_name: "run".to_string(),
|
|
body: instance.body,
|
|
signature,
|
|
resolver,
|
|
store,
|
|
unifier_index: instance.unifier_id,
|
|
calls: instance.calls,
|
|
id: 0,
|
|
};
|
|
|
|
let membuffers: Arc<Mutex<Vec<Vec<u8>>>> = Arc::default();
|
|
let membuffer = membuffers.clone();
|
|
|
|
let f = Arc::new(WithCall::new(Box::new(move |module| {
|
|
let buffer = module.write_bitcode_to_memory();
|
|
let buffer = buffer.as_slice().into();
|
|
membuffer.lock().push(buffer);
|
|
})));
|
|
let threads = (0..threads)
|
|
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
|
|
.collect();
|
|
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
|
|
registry.add_task(task);
|
|
registry.wait_tasks_complete(handles);
|
|
|
|
// Link all modules together into `main`
|
|
let buffers = membuffers.lock();
|
|
let main = context
|
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
|
.unwrap();
|
|
if emit_llvm {
|
|
main.write_bitcode_to_path(Path::new("main.bc"));
|
|
}
|
|
|
|
for (idx, buffer) in buffers.iter().skip(1).enumerate() {
|
|
let other = context
|
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
|
.unwrap();
|
|
|
|
if emit_llvm {
|
|
other.write_bitcode_to_path(Path::new(&format!("module{idx}.bc")));
|
|
}
|
|
|
|
main.link_in_module(other).unwrap();
|
|
}
|
|
|
|
main.link_in_module(irrt).unwrap();
|
|
|
|
// Private all functions except "run"
|
|
let mut function_iter = main.get_first_function();
|
|
while let Some(func) = function_iter {
|
|
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
|
|
func.set_linkage(inkwell::module::Linkage::Private);
|
|
}
|
|
function_iter = func.get_next_function();
|
|
}
|
|
|
|
// Optimize `main`
|
|
let target_machine = llvm_options
|
|
.target
|
|
.create_target_machine(llvm_options.opt_level)
|
|
.expect("couldn't create target machine");
|
|
|
|
let pass_options = PassBuilderOptions::create();
|
|
pass_options.set_merge_functions(true);
|
|
let passes = format!("default<O{}>", opt_level as u32);
|
|
let result = main.run_passes(passes.as_str(), &target_machine, pass_options);
|
|
if let Err(err) = result {
|
|
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
|
}
|
|
|
|
// Write output
|
|
target_machine
|
|
.write_to_file(&main, FileType::Object, Path::new("module.o"))
|
|
.expect("couldn't write module to file");
|
|
}
|