CodeGenerator: Add with_target_machine factory function

Allows creating CodeGenerator with the LLVM target machine to infer the
expected type for size_t.
This commit is contained in:
David Mak 2025-01-13 14:55:33 +08:00
parent 8baf111734
commit d1dcfa19ff
5 changed files with 56 additions and 18 deletions

View File

@ -29,6 +29,7 @@ use nac3core::{
inkwell::{
context::Context,
module::Linkage,
targets::TargetMachine,
types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel,
@ -87,13 +88,13 @@ pub struct ArtiqCodeGenerator<'a> {
impl<'a> ArtiqCodeGenerator<'a> {
pub fn new(
name: String,
size_t: u32,
size_t: IntType<'_>,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
assert!(size_t == 32 || size_t == 64);
assert!(matches!(size_t.get_bit_width(), 32 | 64));
ArtiqCodeGenerator {
name,
size_t,
size_t: size_t.get_bit_width(),
name_counter: 0,
start: None,
end: None,
@ -102,6 +103,17 @@ impl<'a> ArtiqCodeGenerator<'a> {
}
}
#[must_use]
pub fn with_target_machine(
name: String,
ctx: &Context,
target_machine: &TargetMachine,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
Self::new(name, llvm_usize, timeline)
}
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
/// position of the timeline to the initial timeline position before entering the `parallel`
/// block.

View File

@ -703,14 +703,18 @@ impl Nac3 {
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
})));
let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names
.iter()
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
.map(|s| {
Box::new(ArtiqCodeGenerator::with_target_machine(
s.to_string(),
&context,
&self.get_llvm_target_machine(),
self.time_fns,
))
})
.collect();
let membuffer = membuffers.clone();
@ -719,8 +723,13 @@ impl Nac3 {
let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
let context = Context::create();
let mut generator = ArtiqCodeGenerator::with_target_machine(
"main".to_string(),
&context,
&self.get_llvm_target_machine(),
self.time_fns,
);
let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());

View File

@ -1,5 +1,6 @@
use inkwell::{
context::Context,
targets::TargetMachine,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
@ -270,19 +271,27 @@ pub struct DefaultCodeGenerator {
impl DefaultCodeGenerator {
#[must_use]
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t }
pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator {
assert!(matches!(size_t.get_bit_width(), 32 | 64));
DefaultCodeGenerator { name, size_t: size_t.get_bit_width() }
}
#[must_use]
pub fn with_target_machine(
name: String,
ctx: &Context,
target_machine: &TargetMachine,
) -> DefaultCodeGenerator {
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
Self::new(name, llvm_usize)
}
}
impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str {
&self.name
}
/// Returns an LLVM integer type representing `size_t`.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
// it should be unsigned, but we don't really need unsigned and this could save us from
// having to do a bit cast...

View File

@ -97,6 +97,7 @@ fn test_primitives() {
"};
let statements = parse_program(source, FileName::default()).unwrap();
let context = inkwell::context::Context::create();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
@ -107,7 +108,7 @@ fn test_primitives() {
Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) })
as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()];
let signature = FunSignature {
args: vec![
FuncArg {
@ -260,6 +261,7 @@ fn test_simple_call() {
"};
let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let context = inkwell::context::Context::create();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
@ -307,7 +309,7 @@ fn test_simple_call() {
unreachable!()
}
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()];
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
@ -439,7 +441,7 @@ fn test_simple_call() {
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type());
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
@ -459,7 +461,7 @@ fn test_classes_range_type_new() {
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type());
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);

View File

@ -456,7 +456,13 @@ fn main() {
membuffer.lock().push(buffer);
})));
let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
.map(|i| {
Box::new(DefaultCodeGenerator::with_target_machine(
format!("module{i}"),
&context,
&target_machine,
))
})
.collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);