From d1dcfa19ff5696a8a109756398d0fcb3fc970536 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 14:55:33 +0800 Subject: [PATCH 01/49] CodeGenerator: Add with_target_machine factory function Allows creating CodeGenerator with the LLVM target machine to infer the expected type for size_t. --- nac3artiq/src/codegen.rs | 18 +++++++++++++++--- nac3artiq/src/lib.rs | 19 ++++++++++++++----- nac3core/src/codegen/generator.rs | 19 ++++++++++++++----- nac3core/src/codegen/test.rs | 10 ++++++---- nac3standalone/src/main.rs | 8 +++++++- 5 files changed, 56 insertions(+), 18 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 9baa0af..daf539f 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -29,6 +29,7 @@ use nac3core::{ inkwell::{ context::Context, module::Linkage, + targets::TargetMachine, types::{BasicType, IntType}, values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, @@ -87,13 +88,13 @@ pub struct ArtiqCodeGenerator<'a> { impl<'a> ArtiqCodeGenerator<'a> { pub fn new( name: String, - size_t: u32, + size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), ) -> ArtiqCodeGenerator<'a> { - assert!(size_t == 32 || size_t == 64); + assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { name, - size_t, + size_t: size_t.get_bit_width(), name_counter: 0, start: None, end: None, @@ -102,6 +103,17 @@ impl<'a> ArtiqCodeGenerator<'a> { } } + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + timeline: &'a (dyn TimeFns + Sync), + ) -> ArtiqCodeGenerator<'a> { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize, timeline) + } + /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the /// position of the timeline to the initial timeline position before entering the `parallel` /// block. diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 1601d4a..9a69c1a 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -703,14 +703,18 @@ impl Nac3 { let buffer = buffer.as_slice().into(); membuffer.lock().push(buffer); }))); - let size_t = context - .ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None) - .get_bit_width(); let num_threads = if is_multithreaded() { 4 } else { 1 }; let thread_names: Vec = (0..num_threads).map(|_| "main".to_string()).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) + .map(|s| { + Box::new(ArtiqCodeGenerator::with_target_machine( + s.to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + )) + }) .collect(); let membuffer = membuffers.clone(); @@ -719,8 +723,13 @@ impl Nac3 { let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); - let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns); let context = Context::create(); + let mut generator = ArtiqCodeGenerator::with_target_machine( + "main".to_string(), + &context, + &self.get_llvm_target_machine(), + self.time_fns, + ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); module.set_data_layout(&target_machine.get_target_data().get_data_layout()); diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index be007c2..a416f10 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -1,5 +1,6 @@ use inkwell::{ context::Context, + targets::TargetMachine, types::{BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue, PointerValue}, }; @@ -270,19 +271,27 @@ pub struct DefaultCodeGenerator { impl DefaultCodeGenerator { #[must_use] - pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator { - assert!(matches!(size_t, 32 | 64)); - DefaultCodeGenerator { name, size_t } + pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator { + assert!(matches!(size_t.get_bit_width(), 32 | 64)); + DefaultCodeGenerator { name, size_t: size_t.get_bit_width() } + } + + #[must_use] + pub fn with_target_machine( + name: String, + ctx: &Context, + target_machine: &TargetMachine, + ) -> DefaultCodeGenerator { + let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); + Self::new(name, llvm_usize) } } impl CodeGenerator for DefaultCodeGenerator { - /// Returns the name for this [`CodeGenerator`]. fn get_name(&self) -> &str { &self.name } - /// Returns an LLVM integer type representing `size_t`. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { // it should be unsigned, but we don't really need unsigned and this could save us from // having to do a bit cast... diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 2701e13..48bef5f 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -97,6 +97,7 @@ fn test_primitives() { "}; let statements = parse_program(source, FileName::default()).unwrap(); + let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; @@ -107,7 +108,7 @@ fn test_primitives() { Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) as Arc; - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; let signature = FunSignature { args: vec![ FuncArg { @@ -260,6 +261,7 @@ fn test_simple_call() { "}; let statements_2 = parse_program(source_2, FileName::default()).unwrap(); + let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; @@ -307,7 +309,7 @@ fn test_simple_call() { unreachable!() } - let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), @@ -439,7 +441,7 @@ fn test_simple_call() { #[test] fn test_classes_list_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); @@ -459,7 +461,7 @@ fn test_classes_range_type_new() { #[test] fn test_classes_ndarray_type_new() { let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 2fce5d1..d54e08e 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -456,7 +456,13 @@ fn main() { membuffer.lock().push(buffer); }))); let threads = (0..threads) - .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t))) + .map(|i| { + Box::new(DefaultCodeGenerator::with_target_machine( + format!("module{i}"), + &context, + &target_machine, + )) + }) .collect(); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); From 3ebd4ba5d14277a00e953f4dc354fb2fe230156b Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 14:56:22 +0800 Subject: [PATCH 02/49] [core] codegen: Add assertion verifying size_t is compatible --- nac3core/src/codegen/mod.rs | 11 +++++++++++ nac3core/src/codegen/test.rs | 8 ++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e74071b..28b9a65 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -989,6 +989,17 @@ pub fn gen_func_impl< debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), }; + let target_llvm_usize = context.ptr_sized_int_type( + ®istry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(), + None, + ); + let generator_llvm_usize = generator.get_size_type(context); + assert_eq!( + generator_llvm_usize, + target_llvm_usize, + "CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})", + ); + let loc = code_gen_context.debug_info.0.create_debug_location( context, row as u32, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 48bef5f..6518d85 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -98,7 +98,7 @@ fn test_primitives() { let statements = parse_program(source, FileName::default()).unwrap(); let context = inkwell::context::Context::create(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -108,7 +108,7 @@ fn test_primitives() { Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) as Arc; - let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let signature = FunSignature { args: vec![ FuncArg { @@ -262,7 +262,7 @@ fn test_simple_call() { let statements_2 = parse_program(source_2, FileName::default()).unwrap(); let context = inkwell::context::Context::create(); - let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0; + let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0; let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -309,7 +309,7 @@ fn test_simple_call() { unreachable!() } - let threads = vec![DefaultCodeGenerator::new("test".into(), context.i32_type()).into()]; + let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let mut function_data = FunctionData { resolver: resolver.clone(), bound_variables: Vec::new(), From f8530e0ef694c7d8ac7690ecd9eab838b5d59831 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 20:26:15 +0800 Subject: [PATCH 03/49] [core] codegen: Add CodeGenContext::get_size_type Convenience method for getting the `size_t` LLVM type without the use of `CodeGenerator`. --- nac3core/src/codegen/generator.rs | 3 +++ nac3core/src/codegen/mod.rs | 25 +++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index a416f10..620ede0 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -19,6 +19,9 @@ pub trait CodeGenerator { fn get_name(&self) -> &str; /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. + /// + /// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is + /// equivalent to this function in a more concise syntax. fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; /// Generate function call and returns the function return value. diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 28b9a65..797a62b 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,4 +1,5 @@ use std::{ + cell::OnceCell, collections::{HashMap, HashSet}, sync::{ atomic::{AtomicBool, Ordering}, @@ -19,7 +20,7 @@ use inkwell::{ module::Module, passes::PassBuilderOptions, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, - types::{AnyType, BasicType, BasicTypeEnum}, + types::{AnyType, BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; @@ -226,14 +227,33 @@ pub struct CodeGenContext<'ctx, 'a> { /// The current source location. pub current_loc: Location, + + /// The cached type of `size_t`. + llvm_usize: OnceCell>, } -impl CodeGenContext<'_, '_> { +impl<'ctx> CodeGenContext<'ctx, '_> { /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// contains a [terminator statement][BasicBlock::get_terminator]. pub fn is_terminated(&self) -> bool { self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some() } + + /// Returns a [`IntType`] representing `size_t` for the compilation target as specified by + /// [`self.registry`][WorkerRegistry]. + pub fn get_size_type(&self) -> IntType<'ctx> { + *self.llvm_usize.get_or_init(|| { + self.ctx.ptr_sized_int_type( + &self + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data()) + .unwrap(), + None, + ) + }) + } } type Fp = Box; @@ -987,6 +1007,7 @@ pub fn gen_func_impl< need_sret: has_sret, current_loc: Location::default(), debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), + llvm_usize: OnceCell::default(), }; let target_llvm_usize = context.ptr_sized_int_type( From c59fd286ff166c7f1e20f8e8fbaf80666f815faa Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 20:43:57 +0800 Subject: [PATCH 04/49] [artiq] Move `get_llvm_*` to Isa, use `TargetMachine` to infer size_t --- nac3artiq/src/lib.rs | 113 +++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9a69c1a..d35e66d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -78,14 +78,62 @@ enum Isa { } impl Isa { - /// Returns the number of bits in `size_t` for the [`Isa`]. - fn get_size_type(self) -> u32 { - if self == Isa::Host { - 64u32 - } else { - 32u32 + /// Returns the [`TargetTriple`] used for compiling to this ISA. + pub fn get_llvm_target_triple(self) -> TargetTriple { + match self { + Isa::Host => TargetMachine::get_default_triple(), + Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), + Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), } } + + /// Returns the [`String`] representing the target CPU used for compiling to this ISA. + pub fn get_llvm_target_cpu(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_name().to_string(), + Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), + Isa::CortexA9 => "cortex-a9".to_string(), + } + } + + /// Returns the [`String`] representing the target features used for compiling to this ISA. + pub fn get_llvm_target_features(self) -> String { + match self { + Isa::Host => TargetMachine::get_host_cpu_features().to_string(), + Isa::RiscV32G => "+a,+m,+f,+d".to_string(), + Isa::RiscV32IMA => "+a,+m".to_string(), + Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), + } + } + + /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine + /// options used for compiling to this ISA. + pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions { + CodeGenTargetMachineOptions { + triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(), + cpu: self.get_llvm_target_cpu(), + features: self.get_llvm_target_features(), + reloc_mode: RelocMode::PIC, + ..CodeGenTargetMachineOptions::from_host() + } + } + + /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this + /// ISA. + pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine { + self.get_llvm_target_options() + .create_target_machine(opt_level) + .expect("couldn't create target machine") + } + + /// Returns the number of bits in `size_t` for this ISA. + fn get_size_type(self, ctx: &Context) -> u32 { + ctx.ptr_sized_int_type( + &self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(), + None, + ) + .get_bit_width() + } } #[derive(Clone)] @@ -378,7 +426,7 @@ impl Nac3 { py: Python, link_fn: &dyn Fn(&Module) -> PyResult, ) -> PyResult { - let size_t = self.isa.get_size_type(); + let size_t = self.isa.get_size_type(&Context::create()); let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( self.builtins.clone(), Self::get_lateinit_builtins(), @@ -848,52 +896,10 @@ impl Nac3 { link_fn(&main) } - /// Returns the [`TargetTriple`] used for compiling to [isa]. - fn get_llvm_target_triple(isa: Isa) -> TargetTriple { - match isa { - Isa::Host => TargetMachine::get_default_triple(), - Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), - Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), - } - } - - /// Returns the [`String`] representing the target CPU used for compiling to [isa]. - fn get_llvm_target_cpu(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_name().to_string(), - Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(), - Isa::CortexA9 => "cortex-a9".to_string(), - } - } - - /// Returns the [`String`] representing the target features used for compiling to [isa]. - fn get_llvm_target_features(isa: Isa) -> String { - match isa { - Isa::Host => TargetMachine::get_host_cpu_features().to_string(), - Isa::RiscV32G => "+a,+m,+f,+d".to_string(), - Isa::RiscV32IMA => "+a,+m".to_string(), - Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(), - } - } - - /// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine - /// options used for compiling to [isa]. - fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions { - CodeGenTargetMachineOptions { - triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(), - cpu: Nac3::get_llvm_target_cpu(isa), - features: Nac3::get_llvm_target_features(isa), - reloc_mode: RelocMode::PIC, - ..CodeGenTargetMachineOptions::from_host() - } - } - /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the - /// target [isa]. + /// target [ISA][isa]. fn get_llvm_target_machine(&self) -> TargetMachine { - Nac3::get_llvm_target_options(self.isa) - .create_target_machine(self.llvm_options.opt_level) - .expect("couldn't create target machine") + self.isa.create_llvm_target_machine(self.llvm_options.opt_level) } } @@ -1001,7 +1007,8 @@ impl Nac3 { Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, }; - let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type()); + let (primitive, _) = + TopLevelComposer::make_primitives(isa.get_size_type(&Context::create())); let builtins = vec![ ( "now_mu".into(), @@ -1150,7 +1157,7 @@ impl Nac3 { deferred_eval_store: DeferredEvaluationStore::new(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, - target: Nac3::get_llvm_target_options(isa), + target: isa.get_llvm_target_options(), }, }) } From bd66fe48d8acb011ba3b3dc1a6c507db91b56fb6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 13 Jan 2025 21:05:27 +0800 Subject: [PATCH 05/49] [core] codegen: Refactor to use CodeGenContext::get_size_type Simplifies a lot of API usage. --- nac3artiq/src/codegen.rs | 10 +-- nac3artiq/src/symbol_resolver.rs | 4 +- nac3core/src/codegen/builtin_fns.rs | 18 ++-- nac3core/src/codegen/expr.rs | 48 +++++------ nac3core/src/codegen/irrt/list.rs | 4 +- nac3core/src/codegen/irrt/mod.rs | 8 +- nac3core/src/codegen/irrt/ndarray/array.rs | 18 ++-- nac3core/src/codegen/irrt/ndarray/basic.rs | 67 ++++++--------- .../src/codegen/irrt/ndarray/broadcast.rs | 9 +- nac3core/src/codegen/irrt/ndarray/indexing.rs | 2 +- nac3core/src/codegen/irrt/ndarray/iter.rs | 17 ++-- nac3core/src/codegen/irrt/ndarray/matmul.rs | 5 +- nac3core/src/codegen/irrt/ndarray/reshape.rs | 3 +- .../src/codegen/irrt/ndarray/transpose.rs | 4 +- nac3core/src/codegen/irrt/string.rs | 7 +- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 18 ++-- nac3core/src/codegen/stmt.rs | 14 ++- nac3core/src/codegen/types/list.rs | 6 +- nac3core/src/codegen/types/ndarray/array.rs | 6 +- .../src/codegen/types/ndarray/contiguous.rs | 2 +- nac3core/src/codegen/types/ndarray/map.rs | 10 +-- nac3core/src/codegen/types/ndarray/mod.rs | 10 +-- nac3core/src/codegen/types/tuple.rs | 2 +- nac3core/src/codegen/values/array.rs | 2 +- nac3core/src/codegen/values/list.rs | 11 +-- .../src/codegen/values/ndarray/broadcast.rs | 6 +- .../src/codegen/values/ndarray/contiguous.rs | 4 +- .../src/codegen/values/ndarray/indexing.rs | 2 +- nac3core/src/codegen/values/ndarray/matmul.rs | 4 +- nac3core/src/codegen/values/ndarray/mod.rs | 85 ++++++------------- nac3core/src/codegen/values/ndarray/nditer.rs | 18 ++-- nac3core/src/codegen/values/ndarray/shape.rs | 2 +- nac3core/src/codegen/values/ndarray/view.rs | 8 +- nac3core/src/toplevel/builtins.rs | 6 +- 35 files changed, 176 insertions(+), 266 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index daf539f..cb75606 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -471,7 +471,7 @@ fn format_rpc_arg<'ctx>( // libproto_artiq: NDArray = [data[..], dim_sz[..]] let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); @@ -556,7 +556,7 @@ fn format_rpc_ret<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { @@ -697,7 +697,7 @@ fn format_rpc_ret<'ctx>( // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let num_elements = ndarray.size(generator, ctx); + let num_elements = ndarray.size(ctx); let expected_ndarray_nbytes = ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap(); @@ -809,7 +809,7 @@ fn rpc_codegen_callback_fn<'ctx>( ) -> Result>, String> { let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); - let size_type = generator.get_size_type(ctx.ctx); + let size_type = ctx.get_size_type(); let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); @@ -1167,7 +1167,7 @@ fn polymorphic_print<'ctx>( let llvm_i32 = ctx.ctx.i32_type(); let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let suffix = suffix.unwrap_or_default(); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 1e99ac2..d976866 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1007,7 +1007,7 @@ impl InnerResolver { } _ => unreachable!("must be list"), }; - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let ty = if len == 0 && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) { @@ -1096,7 +1096,7 @@ impl InnerResolver { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); let dtype = llvm_ndarray.element_type(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 9d36807..96f8c70 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -64,7 +64,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) .map_value(arg.into_pointer_value(), None); ctx.builder - .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") + .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") .unwrap() } @@ -835,7 +835,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { @@ -870,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let size_nez = ctx .builder - .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "") + .build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "") .unwrap(); ctx.make_assert( @@ -1676,7 +1676,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_qr"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1728,7 +1728,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_svd"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1821,7 +1821,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_pinv"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1862,7 +1862,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "sp_linalg_lu"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; @@ -1915,7 +1915,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) @@ -1968,7 +1968,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e74e7ad..00290d3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -165,7 +165,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); ty.const_named_struct(&[str_ptr, size.into()]).into() } @@ -318,7 +318,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .build_global_string_ptr(v, "const") .map(|v| v.as_pointer_value().into()) .unwrap(); - let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let size = self.get_size_type().const_int(v.len() as u64, false); let ty = self.get_llvm_type(generator, self.primitives.str); let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); @@ -820,7 +820,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result>, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let id; @@ -1020,7 +1020,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( } let is_vararg = args.iter().any(|arg| arg.is_vararg); if is_vararg { - params.push(generator.get_size_type(ctx.ctx).into()); + params.push(ctx.get_size_type().into()); } let fun_ty = match ret_type { Some(ret_type) if !has_sret => ret_type.fn_type(¶ms, is_vararg), @@ -1128,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( return Ok(None); }; let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero_size_t = size_t.const_zero(); let zero_32 = int32.const_zero(); @@ -1258,15 +1258,13 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } // Emits the content of `cont_bb` - let emit_cont_bb = - |ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| { - ctx.builder.position_at_end(cont_bb); - list.store_size( - ctx, - generator, - ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), - ); - }; + let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| { + ctx.builder.position_at_end(cont_bb); + list.store_size( + ctx, + ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), + ); + }; for cond in ifs { let result = if let Some(v) = generator.gen_expr(ctx, cond)? { @@ -1274,7 +1272,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( } else { // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // no element matches the predicate - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1287,7 +1285,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( let Some(elem) = generator.gen_expr(ctx, elt)? else { // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); return Ok(None); }; @@ -1304,7 +1302,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( .unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap(); - emit_cont_bb(ctx, generator, list); + emit_cont_bb(ctx, list); Ok(Some(list.as_base_value().into())) } @@ -1350,7 +1348,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); if op.variant == BinopVariant::AugAssign { todo!("Augmented assignment operators not implemented for lists") @@ -1972,7 +1970,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let rhs = rhs.into_struct_value(); let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); @@ -2000,7 +1998,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); + let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { ctx.builder.build_not(result, "").unwrap() } else { @@ -2010,7 +2008,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .iter() .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let gen_list_cmpop = |generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>| @@ -2375,7 +2373,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ) -> Result>, String> { ctx.current_loc = expr.location; let int32 = ctx.ctx.i32_type(); - let usize = generator.get_size_type(ctx.ctx); + let usize = ctx.get_size_type(); let zero = int32.const_int(0, false); let loc = ctx.debug_info.0.create_debug_location( @@ -2480,7 +2478,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { Some(elements[0].get_type()) }; - let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false); + let length = ctx.get_size_type().const_int(elements.len() as u64, false); let arr_str_ptr = if let Some(ty) = ty { ListType::new(generator, ctx.ctx, ty).construct( generator, @@ -3009,7 +3007,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let raw_index = ctx .builder - .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") + .build_int_s_extend(raw_index, ctx.get_size_type(), "sext") .unwrap(); // handle negative index let is_negative = ctx @@ -3017,7 +3015,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, raw_index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index 2c57f8e..c01e2cb 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -24,7 +24,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( src_arr: ListValue<'ctx>, src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let llvm_i32 = ctx.ctx.i32_type(); @@ -168,7 +168,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.position_at_end(update_bb); let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); - dest_arr.store_size(ctx, generator, new_len); + dest_arr.store_size(ctx, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.position_at_end(cont_bb); } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 4cacdcc..8739178 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -68,13 +68,9 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. #[must_use] -pub fn get_usize_dependent_function_name( - generator: &G, - ctx: &CodeGenContext<'_, '_>, - name: &str, -) -> String { +pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String { let mut name = name.to_owned(); - match generator.get_size_type(ctx.ctx).get_bit_width() { + match ctx.get_size_type().get_bit_width() { 32 => {} 64 => name.push_str("64"), bit_width => { diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs index 931b66c..5e9c0f0 100644 --- a/nac3core/src/codegen/irrt/ndarray/array.rs +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -21,7 +21,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato ndims: IntValue<'ctx>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(ndims.get_type(), llvm_usize); assert_eq!( @@ -29,11 +29,8 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_array_set_and_validate_list_shape", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape"); infer_and_call_function( ctx, @@ -55,19 +52,14 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato /// - `ndarray.ndims`: Must be initialized. /// - `ndarray.shape`: Must be initialized. /// - `ndarray.data`: Must be allocated and contiguous. -pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>, ndarray: NDArrayValue<'ctx>, ) { assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_array_write_list_to_array", - ); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index d11c9b8..aa792b1 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ctx: &CodeGenContext<'ctx, '_>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -28,11 +28,8 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_shape_no_negative", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); create_and_call_function( ctx, @@ -57,7 +54,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -69,11 +66,8 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + llvm_usize.into() ); - let name = get_usize_dependent_function_name( - generator, - ctx, - "__nac3_ndarray_util_assert_output_shape_same", - ); + let name = + get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); create_and_call_function( ctx, @@ -94,15 +88,14 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an /// `ndarray`, corresponding to the value of `ndarray.size`. -pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_size<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); create_and_call_function( ctx, @@ -120,15 +113,14 @@ pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the /// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. -pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_nbytes<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); create_and_call_function( ctx, @@ -146,15 +138,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of /// the `ndarray`, corresponding to the value of `ndarray.__len__`. -pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_len<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); create_and_call_function( ctx, @@ -171,15 +162,14 @@ pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_ndarray_is_c_contiguous`. /// /// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. -pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); create_and_call_function( ctx, @@ -196,20 +186,19 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_ndarray_get_nth_pelement`. /// /// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. -pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, index: IntValue<'ctx>, ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_ndarray = ndarray.get_type().as_base_type(); assert_eq!(index.get_type(), llvm_usize); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); create_and_call_function( ctx, @@ -236,7 +225,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_ndarray = ndarray.get_type().as_base_type(); @@ -245,8 +234,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized llvm_usize.into() ); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); create_and_call_function( ctx, @@ -266,15 +254,13 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized /// Generates a call to `__nac3_ndarray_set_strides_by_shape`. /// /// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. -pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { let llvm_ndarray = ndarray.get_type().as_base_type(); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); create_and_call_function( ctx, @@ -291,13 +277,12 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( /// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number /// of elements in `src_ndarray` must be greater than or equal to the number of elements in /// `dst_ndarray`. -pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_copy_data<'ctx>( ctx: &CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index cb1ecd4..fceba25 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -20,13 +20,12 @@ use crate::codegen::{ /// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. /// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. /// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. -pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_ndarray_broadcast_to<'ctx>( ctx: &CodeGenContext<'ctx, '_>, src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to"); infer_and_call_function( ctx, &name, @@ -53,7 +52,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(num_shape_entries.get_type(), llvm_usize); assert!(ShapeEntryType::is_type( @@ -65,7 +64,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( assert_eq!(dst_ndims.get_type(), llvm_usize); assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index 3e2c908..df5b27d 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -17,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( src_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>, ) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index 47cd5b2..ad90178 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -25,7 +25,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( ndarray: NDArrayValue<'ctx>, indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); assert_eq!( @@ -33,7 +33,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.into() ); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); create_and_call_function( ctx, @@ -53,12 +53,11 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( /// /// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` /// object. -pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_nac3_nditer_has_element<'ctx>( ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>, ) -> IntValue<'ctx> { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element"); + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element"); infer_and_call_function( ctx, @@ -75,12 +74,8 @@ pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>( /// Generates a call to `__nac3_nditer_next`. /// /// Moves `iter` to point to the next element. -pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - iter: NDIterValue<'ctx>, -) { - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next"); +pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { + let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs index 551cb7c..0df774f 100644 --- a/nac3core/src/codegen/irrt/ndarray/matmul.rs +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!( BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), @@ -43,8 +43,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized llvm_usize.into() ); - let name = - get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); infer_and_call_function( ctx, diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs index 32de2fa..66cbf13 100644 --- a/nac3core/src/codegen/irrt/ndarray/reshape.rs +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -18,14 +18,13 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera new_ndims: IntValue<'ctx>, new_shape: ArraySliceValue<'ctx>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert_eq!(size.get_type(), llvm_usize); assert_eq!(new_ndims.get_type(), llvm_usize); assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name( - generator, ctx, "__nac3_ndarray_reshape_resolve_and_check_new_shape", ); diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs index 57661fa..6d152dd 100644 --- a/nac3core/src/codegen/irrt/ndarray/transpose.rs +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -23,12 +23,12 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( dst_ndarray: NDArrayValue<'ctx>, axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, ) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); - let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose"); infer_and_call_function( ctx, &name, diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index 6ee40e4..e2fd8c0 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -2,11 +2,10 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; use super::get_usize_dependent_function_name; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::CodeGenContext; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. -pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, +pub fn call_string_eq<'ctx>( ctx: &CodeGenContext<'ctx, '_>, str1_ptr: PointerValue<'ctx>, str1_len: IntValue<'ctx>, @@ -15,7 +14,7 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq"); + let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { ctx.module.add_function( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 797a62b..f7483ef 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1212,7 +1212,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let align_ty = align_ty.into(); let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6c16be9..6700af4 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -207,7 +207,7 @@ pub fn gen_ndarray_eye<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); - let llvm_usize = generator.get_size_type(context.ctx); + let llvm_usize = context.get_size_type(); let llvm_dtype = context.get_llvm_type(generator, dtype); let nrows = context @@ -244,7 +244,7 @@ pub fn gen_ndarray_identity<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); - let llvm_usize = generator.get_size_type(context.ctx); + let llvm_usize = context.get_size_type(); let llvm_dtype = context.get_llvm_type(generator, dtype); let n = context @@ -325,8 +325,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); // Check shapes. - let a_size = a.size(generator, ctx); - let b_size = b.size(generator, ctx); + let a_size = a.size(ctx); + let b_size = b.size(ctx); let same_shape = ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( @@ -353,9 +353,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); Ok((a_iter, b_iter)) }, - |generator, ctx, (a_iter, _b_iter)| { + |_, ctx, (a_iter, _b_iter)| { // Only a_iter drives the condition, b_iter should have the same status. - Ok(a_iter.has_element(generator, ctx)) + Ok(a_iter.has_element(ctx)) }, |_, ctx, _hooks, (a_iter, b_iter)| { let a_scalar = a_iter.get_scalar(ctx); @@ -385,9 +385,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - |generator, ctx, (a_iter, b_iter)| { - a_iter.next(generator, ctx); - b_iter.next(generator, ctx); + |_, ctx, (a_iter, b_iter)| { + a_iter.next(ctx); + b_iter.next(ctx); Ok(()) }, ) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 2a3bd06..c327405 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -306,7 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { // Handle list item assignment - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; let target = generator @@ -367,10 +367,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, key_ty)? .into_int_value(); - let index = ctx - .builder - .build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext") - .unwrap(); + let index = + ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap(); // handle negative index let is_negative = ctx @@ -378,7 +376,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .build_int_compare( IntPredicate::SLT, index, - generator.get_size_type(ctx.ctx).const_zero(), + ctx.get_size_type().const_zero(), "is_neg", ) .unwrap(); @@ -460,7 +458,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let target = broadcast_result.ndarrays[0]; let value = broadcast_result.ndarrays[1]; - target.copy_data_from(generator, ctx, value); + target.copy_data_from(ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); @@ -484,7 +482,7 @@ pub fn gen_for( let var_assignment = ctx.var_assignment.clone(); let int32 = ctx.ctx.i32_type(); - let size_t = generator.get_size_type(ctx.ctx); + let size_t = ctx.get_size_type(); let zero = int32.const_zero(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let body_bb = ctx.ctx.append_basic_block(current, "for.body"); diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 337d049..9ea4aca 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -152,7 +152,7 @@ impl<'ctx> ListType<'ctx> { _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), }; - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { None } else { @@ -273,7 +273,7 @@ impl<'ctx> ListType<'ctx> { } let plist = self.alloca_var(generator, ctx, name); - plist.store_size(ctx, generator, len); + plist.store_size(ctx, len); let item = self.item.unwrap_or(self.llvm_usize.into()); plist.create_data(ctx, item, None); @@ -300,7 +300,7 @@ impl<'ctx> ListType<'ctx> { ) -> >::Value { let plist = self.alloca_var(generator, ctx, name); - plist.store_size(ctx, generator, self.llvm_usize.const_zero()); + plist.store_size(ctx, self.llvm_usize.const_zero()); plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); plist diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 0f30f0e..b0c9d63 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -67,9 +67,7 @@ impl<'ctx> NDArrayType<'ctx> { unsafe { ndarray.create_data(generator, ctx) }; // Copy all contents from the list. - irrt::ndarray::call_nac3_ndarray_array_write_list_to_array( - generator, ctx, list_value, ndarray, - ); + irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray); ndarray } @@ -116,7 +114,7 @@ impl<'ctx> NDArrayType<'ctx> { } // Set strides, the `data` is contiguous - ndarray.set_strides_contiguous(generator, ctx); + ndarray.set_strides_contiguous(ctx); ndarray } else { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index e5fb8cd..f4a8b73 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -140,7 +140,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } } diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index bf82b4d..6fdd9e1 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -86,10 +86,10 @@ impl<'ctx> NDArrayType<'ctx> { .collect_vec(); Ok((nditer, other_nditers)) }, - |generator, ctx, (out_nditer, _in_nditers)| { + |_, ctx, (out_nditer, _in_nditers)| { // We can simply use `out_nditer`'s `has_element()`. // `in_nditers`' `has_element()`s should return the same value. - Ok(out_nditer.has_element(generator, ctx)) + Ok(out_nditer.has_element(ctx)) }, |generator, ctx, _hooks, (out_nditer, in_nditers)| { // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, @@ -104,10 +104,10 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) }, - |generator, ctx, (out_nditer, in_nditers)| { + |_, ctx, (out_nditer, in_nditers)| { // Advance all iterators - out_nditer.next(generator, ctx); - in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + out_nditer.next(ctx); + in_nditers.iter().for_each(|nditer| nditer.next(ctx)); Ok(()) }, )?; diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 353ace3..a7bcb7e 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -158,7 +158,7 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let ndims = extract_ndims(&ctx.unifier, ndims); NDArrayType { @@ -259,9 +259,9 @@ impl<'ctx> NDArrayType<'ctx> { .builder .build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") .unwrap(); - ndarray.store_itemsize(ctx, generator, itemsize); + ndarray.store_itemsize(ctx, itemsize); - ndarray.store_ndims(ctx, generator, ndims); + ndarray.store_ndims(ctx, ndims); ndarray.create_shape(ctx, self.llvm_usize, ndims); ndarray.create_strides(ctx, self.llvm_usize, ndims); @@ -307,7 +307,7 @@ impl<'ctx> NDArrayType<'ctx> { let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); @@ -342,7 +342,7 @@ impl<'ctx> NDArrayType<'ctx> { let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Write shape let ndarray_shape = ndarray.shape(); diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index ccb63b4..947f95a 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -52,7 +52,7 @@ impl<'ctx> TupleType<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, ty: Type, ) -> Self { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Sanity check on object type. let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs index b756f27..9f6652b 100644 --- a/nac3core/src/codegen/values/array.rs +++ b/nac3core/src/codegen/values/array.rs @@ -418,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index c497f8f..08d2b6b 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -97,13 +97,8 @@ impl<'ctx> ListValue<'ctx> { } /// Stores the `size` of this `list` into this instance. - pub fn store_size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - size: IntValue<'ctx>, - ) { - debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { + debug_assert_eq!(size.get_type(), ctx.get_size_type()); self.len_field(ctx).set(ctx, self.value, size, self.name); } @@ -213,7 +208,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + debug_assert_eq!(idx.get_type(), ctx.get_size_type()); let size = self.size(ctx, generator); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 1b99f46..b145746 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -112,7 +112,7 @@ impl<'ctx> NDArrayValue<'ctx> { target_shape.base_ptr(ctx, generator), ); - irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray); + irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray); broadcast_ndarray } } @@ -146,7 +146,7 @@ fn broadcast_shapes<'ctx, G, Shape>( Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); assert!(in_shape_entries @@ -199,7 +199,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> BroadcastAllResult<'ctx, G> { assert!(!ndarrays.is_empty()); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); // Infer the broadcast output ndims. let broadcast_ndims_int = diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 8eb700b..52082df 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -130,7 +130,7 @@ impl<'ctx> NDArrayValue<'ctx> { gen_if_callback( generator, ctx, - |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); @@ -184,7 +184,7 @@ impl<'ctx> NDArrayValue<'ctx> { // Copy shape and update strides let shape = carray.load_shape(ctx); ndarray.copy_shape_from_array(generator, ctx, shape); - ndarray.set_strides_contiguous(generator, ctx); + ndarray.set_strides_contiguous(ctx); // Share data let data = carray.load_data(ctx); diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3821f23..1a96522 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -245,7 +245,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx)) + SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type()) .alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index f802c0c..a24316b 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -35,7 +35,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); // Deduce ndims of the result of matmul. @@ -315,7 +315,7 @@ impl<'ctx> NDArrayValue<'ctx> { let result_shape = result.shape(); out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); - out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray.copy_data_from(ctx, result); out_ndarray } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index b32a8f6..595345e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -81,13 +81,8 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { + debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); let pndims = self.ptr_to_ndims(ctx); ctx.builder.build_store(pndims, ndims).unwrap(); @@ -104,13 +99,8 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the size of each element `itemsize` into this instance. - pub fn store_itemsize( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - itemsize: IntValue<'ctx>, - ) { - debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx)); + pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { + debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); } @@ -205,12 +195,12 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) { - let nbytes = self.nbytes(generator, ctx); + let nbytes = self.nbytes(ctx); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None); self.store_data(ctx, data); - self.set_strides_contiguous(generator, ctx); + self.set_strides_contiguous(ctx); } /// Returns a proxy object to the field storing the data of this `NDArray`. @@ -284,52 +274,32 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Get the `np.size()` of this ndarray. - pub fn size( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self) + pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_size(ctx, *self) } /// Get the `ndarray.nbytes` of this ndarray. - pub fn nbytes( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self) + pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self) } /// Get the `len()` of this ndarray. - pub fn len( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self) + pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_len(ctx, *self) } /// Check if this ndarray is C-contiguous. /// /// See NumPy's `flags["C_CONTIGUOUS"]`: - pub fn is_c_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self) + pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self) } /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// /// Update the ndarray's strides to make the ndarray contiguous. - pub fn set_strides_contiguous( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) { - irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); + pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self); } /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and @@ -347,7 +317,7 @@ impl<'ctx> NDArrayValue<'ctx> { let shape = self.shape(); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { clone.create_data(generator, ctx) }; - clone.copy_data_from(generator, ctx, *self); + clone.copy_data_from(ctx, *self); clone } @@ -357,14 +327,9 @@ impl<'ctx> NDArrayValue<'ctx> { /// do not matter. The copying order is determined by how their flattened views look. /// /// Panics if the `dtype`s of ndarrays are different. - pub fn copy_data_from( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - src: NDArrayValue<'ctx>, - ) { + pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) { assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); - irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self); + irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self); } /// Fill the ndarray with a scalar. @@ -468,7 +433,7 @@ impl<'ctx> NDArrayValue<'ctx> { ) -> Option> { if self.is_unsized() { // NOTE: `np.size(self) == 0` here is never possible. - let zero = generator.get_size_type(ctx.ctx).const_zero(); + let zero = ctx.get_size_type().const_zero(); let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; Some(value) @@ -756,9 +721,9 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { fn size( &self, ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + _: &G, ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0) + irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0) } } @@ -770,7 +735,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx); + let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx); // Current implementation is transparent - The returned pointer type is // already cast into the expected type, allowing for immediately @@ -834,7 +799,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into()); + assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into()); let indices = TypedArrayLikeAdapter::from( indices.as_slice_value(ctx, generator), @@ -867,7 +832,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> indices: &Index, name: Option<&str>, ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let indices_size = indices.size(ctx, generator); let nidx_leq_ndims = ctx diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 4b4e07a..3784193 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -53,20 +53,16 @@ impl<'ctx> NDIterValue<'ctx> { /// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is 0-sized, this always returns false. #[must_use] - pub fn has_element( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> IntValue<'ctx> { - irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self) + pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + irrt::ndarray::call_nac3_nditer_has_element(ctx, *self) } /// Go to the next element. If `has_element()` is false, then this has undefined behavior. /// /// If `ndarray` is unsized, this can only be called once. /// If `ndarray` is 0-sized, this can never be called. - pub fn next(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) { - irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self); + pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) { + irrt::ndarray::call_nac3_nditer_next(ctx, *self); } fn element_field( @@ -167,10 +163,10 @@ impl<'ctx> NDArrayValue<'ctx> { |generator, ctx| { Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) }, - |generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), - |generator, ctx, nditer| { - nditer.next(generator, ctx); + |_, ctx, nditer| { + nditer.next(ctx); Ok(()) }, ) diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index 190a1e4..3ac2795 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -30,7 +30,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), ) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize = ctx.get_size_type(); let zero = llvm_usize.const_zero(); let one = llvm_usize.const_int(1, false); diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index 5027be5..f68931f 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -70,7 +70,7 @@ impl<'ctx> NDArrayValue<'ctx> { dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); // Resolve negative indices - let size = self.size(generator, ctx); + let size = self.size(ctx); let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_shape = dst_ndarray.shape(); irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( @@ -84,10 +84,10 @@ impl<'ctx> NDArrayValue<'ctx> { gen_if_callback( generator, ctx, - |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| Ok(self.is_c_contiguous(ctx)), |generator, ctx| { // Reshape is possible without copying - dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.set_strides_contiguous(ctx); dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); Ok(()) @@ -97,7 +97,7 @@ impl<'ctx> NDArrayValue<'ctx> { unsafe { dst_ndarray.create_data(generator, ctx); } - dst_ndarray.copy_data_from(generator, ctx, *self); + dst_ndarray.copy_data_from(ctx, *self); Ok(()) }, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 600276b..a0673a1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1278,11 +1278,7 @@ impl<'a> BuiltinBuilder<'a> { let size = ctx .builder - .build_int_truncate_or_bit_cast( - ndarray.size(generator, ctx), - ctx.ctx.i32_type(), - "", - ) + .build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "") .unwrap(); Ok(Some(size.into())) }), From 8e614d83de169a5dbd4a36219a3ab37a09c236e7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 14 Jan 2025 18:20:09 +0800 Subject: [PATCH 06/49] [core] codegen: Add ProxyType::new overloads and refactor to use them --- nac3artiq/src/codegen.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 130 ++++++++++-------- nac3core/src/codegen/expr.rs | 42 ++---- nac3core/src/codegen/mod.rs | 6 +- nac3core/src/codegen/numpy.rs | 16 +-- nac3core/src/codegen/stmt.rs | 3 +- nac3core/src/codegen/test.rs | 4 +- nac3core/src/codegen/types/list.rs | 45 ++++-- nac3core/src/codegen/types/ndarray/array.rs | 11 +- .../src/codegen/types/ndarray/broadcast.rs | 20 ++- .../src/codegen/types/ndarray/contiguous.rs | 22 ++- .../src/codegen/types/ndarray/indexing.rs | 17 ++- nac3core/src/codegen/types/ndarray/map.rs | 14 +- nac3core/src/codegen/types/ndarray/mod.rs | 80 +++++++---- nac3core/src/codegen/types/ndarray/nditer.rs | 20 ++- nac3core/src/codegen/types/tuple.rs | 27 +++- nac3core/src/codegen/types/utils/slice.rs | 24 +++- nac3core/src/codegen/values/list.rs | 8 +- .../src/codegen/values/ndarray/broadcast.rs | 4 +- .../src/codegen/values/ndarray/contiguous.rs | 11 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 22 +-- nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/ndarray/view.rs | 2 +- 25 files changed, 320 insertions(+), 228 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index cb75606..c968198 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -476,8 +476,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) - .map_value(arg.into_pointer_value(), None); + let ndarray = + NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -609,7 +609,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) + let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 96f8c70..911e3dc 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -752,24 +752,20 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1015,24 +1011,20 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1652,7 +1644,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1686,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1715,8 +1707,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let q = q.as_base_value().as_basic_value_enum(); let r = r.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) - .construct_from_objects(ctx, [q, r], None); + let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( + ctx, + [q, r], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1746,8 +1741,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray1_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1775,7 +1770,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let u = u.as_base_value().as_basic_value_enum(); let s = s.as_base_value().as_basic_value_enum(); let vh = vh.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) .construct_from_objects(ctx, [u, s, vh], None); Ok(tuple.as_base_value().into()) } @@ -1796,7 +1791,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,8 +1833,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) - .construct_dyn_shape(generator, ctx, &[d0, d1], None); + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2).construct_dyn_shape( + generator, + ctx, + &[d0, d1], + None, + ); unsafe { out.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -1880,7 +1879,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1901,8 +1900,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let l = l.as_base_value().as_basic_value_enum(); let u = u.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) - .construct_from_objects(ctx, [l, u], None); + let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( + ctx, + [l, u], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1936,11 +1938,11 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) }; - let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into()) + let x2 = NDArrayType::new_unsized(ctx, ctx.ctx.f64_type().into()) .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,8 +1981,12 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) - .construct_const_shape(generator, ctx, &[1], None); + let det = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1).construct_const_shape( + generator, + ctx, + &[1], + None, + ); unsafe { det.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -2014,7 +2020,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2037,8 +2043,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let t = t.as_base_value().as_basic_value_enum(); let z = z.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) - .construct_from_objects(ctx, [t, z], None); + let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( + ctx, + [t, z], + None, + ); Ok(tuple.as_base_value().into()) } @@ -2059,7 +2068,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); @@ -2082,7 +2091,10 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let h = h.as_base_value().as_basic_value_enum(); let q = q.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) - .construct_from_objects(ctx, [h, q], None); + let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( + ctx, + [h, q], + None, + ); Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 00290d3..8f52e92 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1167,7 +1167,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( + list = ListType::new(ctx, &elem_ty).construct( generator, ctx, list_alloc_size.into_int_value(), @@ -1218,12 +1218,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( - generator, - ctx, - length, - Some("listcomp"), - ); + list = ListType::new(ctx, &elem_ty).construct(generator, ctx, length, Some("listcomp")); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1386,8 +1381,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) - .construct(generator, ctx, size, None); + let new_list = + ListType::new(ctx, &llvm_elem_ty).construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1474,7 +1469,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( + let new_list = ListType::new(ctx, &elem_llvm_ty).construct( generator, ctx, ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), @@ -1576,8 +1571,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let right = right.to_ndarray(generator, ctx); let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, llvm_common_dtype, &[left.get_type(), right.get_type()], ) @@ -1850,8 +1844,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .to_ndarray(generator, ctx); let result_ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ctx.ctx.i8_type().into(), &[left.get_type(), right.get_type()], ) @@ -2480,18 +2473,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let length = ctx.get_size_type().const_int(elements.len() as u64, false); let arr_str_ptr = if let Some(ty) = ty { - ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("list"), - ) + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("list")) } else { - ListType::new_untyped(generator, ctx.ctx).construct_empty( - generator, - ctx, - Some("list"), - ) + ListType::new_untyped(ctx).construct_empty(generator, ctx, Some("list")) }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { @@ -2970,12 +2954,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("ret"), - ); + let res_array_ret = + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("ret")); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f7483ef..dcfa2b8 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -530,7 +530,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -540,7 +540,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( @@ -594,7 +594,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6700af4..3cdd1ef 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,7 +120,7 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full( generator, context, &shape, @@ -223,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -251,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -349,8 +349,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ctx, Some("np_dot"), |generator, ctx| { - let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); - let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + let a_iter = NDIterType::new(ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(ctx).construct(generator, ctx, b); Ok((a_iter, b_iter)) }, |_, ctx, (a_iter, _b_iter)| { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index c327405..85a894a 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -448,8 +448,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( - generator, - ctx.ctx, + ctx, value.get_type().element_type(), broadcast_ndims, ) diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 6518d85..a58a984 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -446,7 +446,7 @@ fn test_classes_list_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into()); + let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); } @@ -466,6 +466,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); + let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 9ea4aca..637cced 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -104,7 +104,7 @@ impl<'ctx> ListType<'ctx> { element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - let element_type = element_type.unwrap_or(llvm_usize.into()); + let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum()); let field_tys = Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); @@ -112,26 +112,45 @@ impl<'ctx> ListType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl( + ctx: &'ctx Context, + element_type: Option>, + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + + Self { ty: llvm_list, item: element_type, llvm_usize } + } + /// Creates an instance of [`ListType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, element_type: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - - Self { ty: llvm_list, item: Some(element_type), llvm_usize } + Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx)) } /// Creates an instance of [`ListType`] with an unknown element type. #[must_use] - pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, None, ctx.get_size_type()) + } - Self { ty: llvm_list, item: None, llvm_usize } + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, None, generator.get_size_type(ctx)) } /// Creates an [`ListType`] from a [unifier type][Type]. @@ -159,11 +178,7 @@ impl<'ctx> ListType<'ctx> { Some(ctx.get_llvm_type(generator, elem_type)) }; - Self { - ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), - item: llvm_elem_type, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) } /// Creates an [`ListType`] from a [`PointerType`]. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index b0c9d63..7061112 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -44,7 +44,7 @@ impl<'ctx> NDArrayType<'ctx> { assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); - let list_value = list.as_i8_list(generator, ctx); + let list_value = list.as_i8_list(ctx); // Validate `list` has a consistent shape. // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. @@ -61,8 +61,8 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -96,8 +96,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, 1) - .construct_uninitialized(generator, ctx, name); + let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name); // Set data let data = ctx @@ -168,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) + NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 5ee2845..3a1fd8d 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -79,15 +79,27 @@ impl<'ctx> ShapeEntryType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`ShapeEntryType`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ty, llvm_usize } } + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index f4a8b73..c751d57 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -117,17 +117,26 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); + + Self { ty: llvm_cndarray, item, llvm_usize } + } + /// Creates an instance of [`ContiguousNDArrayType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type()) + } + + /// Creates an instance of [`ContiguousNDArrayType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); - - Self { ty: llvm_cndarray, item, llvm_usize } + Self::new_impl(ctx, item, generator.get_size_type(ctx)) } /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. @@ -140,9 +149,8 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); - Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } + Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) } /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 644e173..3e4e136 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -75,14 +75,25 @@ impl<'ctx> NDIndexType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ndindex, llvm_usize } } + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 6fdd9e1..ae24458 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -46,9 +46,8 @@ impl<'ctx> NDArrayType<'ctx> { let out_ndarray = match out { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. - let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) - .construct_uninitialized(generator, ctx, None); + let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims) + .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, ctx, @@ -70,7 +69,7 @@ impl<'ctx> NDArrayType<'ctx> { }; // Map element-wise and store results into `mapped_ndarray`. - let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray); gen_for_callback( generator, ctx, @@ -80,9 +79,7 @@ impl<'ctx> NDArrayType<'ctx> { let other_nditers = broadcast_result .ndarrays .iter() - .map(|ndarray| { - NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) - }) + .map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray)) .collect_vec(); Ok((nditer, other_nditers)) }, @@ -169,8 +166,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { // Promote all input to ndarrays and map through them. let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); let ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ret_dtype, &inputs.iter().map(NDArrayValue::get_type).collect_vec(), ) diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index a7bcb7e..fe73307 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -107,24 +107,56 @@ impl<'ctx> NDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDArrayType`]. - #[must_use] - pub fn new( - generator: &G, + fn new_impl( ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ndims: u64, + llvm_usize: IntType<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self { + Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + ) -> Self { + Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx)) + } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more /// `ndarray` operands. #[must_use] - pub fn new_broadcast( + pub fn new_broadcast( + ctx: &CodeGenContext<'ctx, '_>, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new_impl( + ctx.ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, @@ -132,20 +164,28 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) + Self::new_impl( + ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + generator.get_size_type(ctx), + ) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] - pub fn new_unsized( + pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self { + Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. + #[must_use] + pub fn new_unsized_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - - NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } + Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx)) } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -158,15 +198,9 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); let ndims = extract_ndims(&ctx.unifier, ndims); - NDArrayType { - ty: Self::llvm_type(ctx.ctx, llvm_usize), - dtype: llvm_dtype, - ndims, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) } /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. @@ -304,7 +338,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -339,7 +373,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -389,8 +423,8 @@ impl<'ctx> NDArrayType<'ctx> { .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type()) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name); ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap(); ndarray } diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c77e457..1d83742 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -86,15 +86,27 @@ impl<'ctx> NDIterType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDIter`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_nditer = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_nditer, llvm_usize } } + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 947f95a..5c73652 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -32,17 +32,34 @@ impl<'ctx> TupleType<'ctx> { ctx.struct_type(tys, false) } + fn new_impl( + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + /// Creates an instance of [`TupleType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self { + Self::new_impl( + ctx.ctx, + &tys.iter().map(BasicType::as_basic_type_enum).collect_vec(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>], ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_tuple = Self::llvm_type(ctx, tys); - - Self { ty: llvm_tuple, llvm_usize } + Self::new_impl(ctx, tys, generator.get_size_type(ctx)) } /// Creates an [`TupleType`] from a [unifier type][Type]. diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index fa5a347..0ef4d1b 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -122,19 +122,31 @@ impl<'ctx> SliceType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. - #[must_use] - pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, int_ty); Self { ty: llvm_ty, int_ty, llvm_usize } } + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) + } + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. #[must_use] - pub fn new_usize(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - Self::new(ctx, llvm_usize, llvm_usize) + pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type()) + } + + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. + #[must_use] + pub fn new_usize_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) } /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 08d2b6b..4ba5b6a 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -114,13 +114,9 @@ impl<'ctx> ListValue<'ctx> { /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. #[must_use] - pub fn as_i8_list( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> ListValue<'ctx> { + pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index b145746..b5182a2 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -104,7 +104,7 @@ impl<'ctx> NDArrayValue<'ctx> { assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims) .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, @@ -147,7 +147,7 @@ fn broadcast_shapes<'ctx, G, Shape>( + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { let llvm_usize = ctx.get_size_type(); - let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(ctx); assert!(in_shape_entries .iter() diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 52082df..0fbb85f 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -117,8 +117,8 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { - let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca_var(generator, ctx, self.name); + let result = + ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name); // Set ndims and shape. let ndims = self.llvm_usize.const_int(self.ndims, false); @@ -178,8 +178,11 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) - .construct_uninitialized(generator, ctx, carray.name); + let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized( + generator, + ctx, + carray.name, + ); // Copy shape and update strides let shape = carray.load_shape(ctx); diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 1a96522..60c9c3b 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -128,11 +128,10 @@ impl<'ctx> NDArrayValue<'ctx> { indices: &[RustNDIndex<'ctx>], ) -> Self { let dst_ndims = self.deduce_ndims_after_indexing_with(indices); - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); - let indices = - NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices); + let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices); irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); dst_ndarray @@ -245,8 +244,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type()) - .alloca_var(generator, ctx, None); + SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index a24316b..f12d36c 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -108,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) + let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 595345e..705412e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -377,12 +377,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Create the strides tuple of this ndarray like @@ -411,12 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. @@ -998,10 +990,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ) -> NDArrayValue<'ctx> { match self { ScalarOrNDArray::NDArray(ndarray) => *ndarray, - ScalarOrNDArray::Scalar(scalar) => { - NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) - .construct_unsized(generator, ctx, scalar, None) - } + ScalarOrNDArray::Scalar(scalar) => NDArrayType::new_unsized(ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 3784193..dd900d6 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -160,9 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator, ctx, Some("ndarray_foreach"), - |generator, ctx| { - Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) - }, + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |_, ctx, nditer| { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index f68931f..9ab3d30 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -65,7 +65,7 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); From 762a2447c3f57e9a09c014d39fe2869499e0dbd9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 16:08:55 +0800 Subject: [PATCH 07/49] [core] codegen: Remove obsolete comments Comments regarding the need for `llvm.stack{save,restore}` is obsolete now that `NDIter::indices` is allocated at the beginning of the function. --- nac3core/src/codegen/types/ndarray/nditer.rs | 5 ----- nac3core/src/codegen/values/ndarray/nditer.rs | 4 ---- 2 files changed, 9 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 1d83742..45b6bb0 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -163,11 +163,6 @@ impl<'ctx> NDIterType<'ctx> { } /// Allocate an [`NDIter`] that iterates through the given `ndarray`. - /// - /// Note: This function allocates an array on the stack at the current builder location, which - /// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to - /// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the - /// [`NDIter`] is no longer needed. #[must_use] pub fn construct( &self, diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index dd900d6..86f370e 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -137,10 +137,6 @@ impl<'ctx> NDArrayValue<'ctx> { /// /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// get properties of the current iteration (e.g., the current element, indices, etc.) - /// - /// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and - /// after invoking this function respectively. See [`NDIterType::construct`] for an explanation - /// on why this is suggested. pub fn foreach<'a, G, F>( &self, generator: &mut G, From 357970a793b837b2553cf44e5c9ee3d92104b1d2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:19:18 +0800 Subject: [PATCH 08/49] [core] codegen/stmt: Add build_{break,continue}_branch functions --- nac3core/src/codegen/expr.rs | 4 +--- nac3core/src/codegen/stmt.rs | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8f52e92..30b8dcd 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2122,9 +2122,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ) .unwrap(); - ctx.builder - .build_unconditional_branch(hooks.exit_bb) - .unwrap(); + hooks.build_break_branch(&ctx.builder); Ok(()) }, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 85a894a..7b99bc2 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,6 +1,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, + builder::Builder, types::{BasicType, BasicTypeEnum}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, @@ -662,11 +663,25 @@ pub fn gen_for( #[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] pub struct BreakContinueHooks<'ctx> { /// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop. - pub exit_bb: BasicBlock<'ctx>, + exit_bb: BasicBlock<'ctx>, /// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration /// of the loop. - pub latch_bb: BasicBlock<'ctx>, + latch_bb: BasicBlock<'ctx>, +} + +impl<'ctx> BreakContinueHooks<'ctx> { + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the exit + /// [`BasicBlock`], as if by calling `break`. + pub fn build_break_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.exit_bb).unwrap(); + } + + /// Creates a [`br` instruction][Builder::build_unconditional_branch] to the latch + /// [`BasicBlock`], as if by calling `continue`. + pub fn build_continue_branch(&self, builder: &Builder<'ctx>) { + builder.build_unconditional_branch(self.latch_bb).unwrap(); + } } /// Generates a C-style `for` construct using lambdas, similar to the following C code: From 18e8e5269fbbac609021b60bf95a6f1125166f3f Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:35:31 +0800 Subject: [PATCH 09/49] [core] codegen/values/ndarray: Add fold utilities Needed for np_{any,all}. --- nac3core/src/codegen/values/ndarray/fold.rs | 101 ++++++++++++++++++++ nac3core/src/codegen/values/ndarray/mod.rs | 1 + 2 files changed, 102 insertions(+) create mode 100644 nac3core/src/codegen/values/ndarray/fold.rs diff --git a/nac3core/src/codegen/values/ndarray/fold.rs b/nac3core/src/codegen/values/ndarray/fold.rs new file mode 100644 index 0000000..7c8aebd --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/fold.rs @@ -0,0 +1,101 @@ +use inkwell::values::{BasicValue, BasicValueEnum}; + +use super::{NDArrayValue, NDIterValue, ScalarOrNDArray}; +use crate::codegen::{ + stmt::{gen_for_callback, BreakContinueHooks}, + types::ndarray::NDIterType, + CodeGenContext, CodeGenerator, +}; + +impl<'ctx> NDArrayValue<'ctx> { + /// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the + /// final value. + /// + /// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance + /// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the + /// properties of the current iterated element. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + V, + NDIterValue<'ctx>, + ) -> Result, + { + let acc_ptr = + generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap(); + ctx.builder.build_store(acc_ptr, init).unwrap(); + + gen_for_callback( + generator, + ctx, + Some("ndarray_fold"), + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), + |_, ctx, nditer| Ok(nditer.has_element(ctx)), + |generator, ctx, hooks, nditer| { + let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap(); + let acc = f(generator, ctx, hooks, acc, nditer)?; + ctx.builder.build_store(acc_ptr, acc).unwrap(); + Ok(()) + }, + |_, ctx, nditer| { + nditer.next(ctx); + Ok(()) + }, + )?; + + let acc = ctx.builder.build_load(acc_ptr, "").unwrap(); + Ok(V::try_from(acc).unwrap()) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// See [`NDArrayValue::fold`]. + /// + /// The primary differences between this function and `NDArrayValue::fold` are: + /// + /// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not + /// available if this instance represents a scalar value. + /// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will + /// be created if this instance represents a scalar value. + pub fn fold<'a, G, V, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: V, + f: F, + ) -> Result + where + G: CodeGenerator + ?Sized, + V: BasicValue<'ctx> + TryFrom>, + >>::Error: std::fmt::Debug, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + Option<&BreakContinueHooks<'ctx>>, + V, + BasicValueEnum<'ctx>, + ) -> Result, + { + match self { + ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v), + ScalarOrNDArray::NDArray(v) => { + v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| { + let elem = nditer.get_scalar(ctx); + f(generator, ctx, Some(&hooks), acc, elem) + }) + } + } + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 705412e..1bf5db3 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -30,6 +30,7 @@ pub use nditer::*; mod broadcast; mod contiguous; +mod fold; mod indexing; mod map; mod matmul; From 1cfaa1a77952f12456440557ae1ad1a3493253ff Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 15 Jan 2025 15:36:03 +0800 Subject: [PATCH 10/49] [core] toplevel: Implement np_{any,all} --- nac3core/src/toplevel/builtins.rs | 70 +++++++++++++++++++++++++-- nac3core/src/toplevel/helper.rs | 4 ++ nac3standalone/demo/interpret_demo.py | 2 + nac3standalone/demo/src/ndarray.py | 56 +++++++++++++++++++++ 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index a0673a1..e06366c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -6,7 +6,8 @@ use strum::IntoEnumIterator; use super::{ helper::{ - debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails, + arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims, + make_exception_fields, PrimDef, PrimDefDetails, }, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, *, @@ -15,9 +16,12 @@ use crate::{ codegen::{ builtin_fns, numpy::*, - stmt::exn_constructor, + stmt::{exn_constructor, gen_if_callback}, types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue}, + values::{ + ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, + ProxyValue, RangeValue, + }, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, @@ -405,6 +409,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim), + PrimDef::FunNpSin | PrimDef::FunNpCos | PrimDef::FunNpTan @@ -1720,6 +1726,64 @@ impl<'a> BuiltinBuilder<'a> { ) } + fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]); + + let param_ty = &[(self.num_or_ndarray_ty.ty, "a")]; + let ret_ty = self.primitives.bool; + let var_map = &self.num_or_ndarray_var_map; + let codegen_callback: Box = + Box::new(move |ctx, _, fun, args, generator| { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i1_k0 = llvm_i1.const_zero(); + let llvm_i1_k1 = llvm_i1.const_all_ones(); + + let a_ty = fun.0.args[0].ty; + let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val)); + let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty); + + let (init, sc_val) = match prim { + PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1), + PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0), + _ => unreachable!(), + }; + + let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| { + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, acc, sc_val, "") + .unwrap()) + }, + |_, ctx| { + if let Some(hooks) = hooks { + hooks.build_break_branch(&ctx.builder); + } + Ok(()) + }, + |_, _| Ok(()), + )?; + + let is_truthy = + builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value(); + + Ok(match prim { + PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(), + PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(), + _ => unreachable!(), + }) + })?; + + Ok(Some(acc.as_basic_value_enum())) + }); + + create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback) + } + /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index de90a41..72d3eaa 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -111,6 +111,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunNpAny, + FunNpAll, // Linalg functions FunNpDot, @@ -305,6 +307,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunNpAny => fun("np_any", None), + PrimDef::FunNpAll => fun("np_all", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index fa91ed3..180d24f 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -232,6 +232,8 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter + module.np_any = np.any + module.np_all = np.all # SciPy Math functions module.sp_spec_erf = special.erf diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index b668860..d077b82 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1551,6 +1551,59 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) + +def test_ndarray_any(): + s0 = 0 + output_bool(np_any(s0)) + s1 = 1 + output_bool(np_any(s1)) + + x1 = np_identity(5) + y1 = np_any(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_any(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_any(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_any(x4) + output_ndarray_float_2(x4) + output_bool(y4) + +def test_ndarray_all(): + s0 = 0 + output_bool(np_all(s0)) + s1 = 1 + output_bool(np_all(s1)) + + x1 = np_identity(5) + y1 = np_all(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_all(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_all(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_all(x4) + output_ndarray_float_2(x4) + output_bool(y4) + def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1851,6 +1904,9 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() + test_ndarray_any() + test_ndarray_all() + test_ndarray_dot() test_ndarray_cholesky() test_ndarray_qr() From 933804e2707d3e961496d60615693144351f0049 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 15 Jan 2025 21:18:45 +0800 Subject: [PATCH 11/49] update dependencies --- Cargo.lock | 111 ++++++++++++++++++++++++++++------------------------- flake.lock | 6 +-- 2 files changed, 62 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 646700d..c1d9352 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,11 +65,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] @@ -105,9 +106,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "block-buffer" @@ -126,9 +127,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.7" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -141,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -151,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -163,14 +164,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -472,7 +473,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -581,9 +582,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "llvm-sys" @@ -610,9 +611,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "memchr" @@ -678,7 +679,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", "trybuild", ] @@ -761,45 +762,45 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_macros", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", ] [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ - "phf_shared 0.11.2", + "phf_shared 0.11.3", "rand", ] [[package]] name = "phf_macros" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" dependencies = [ "phf_generator", - "phf_shared 0.11.2", + "phf_shared 0.11.3", "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -808,16 +809,16 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" dependencies = [ - "siphasher", + "siphasher 0.3.11", ] [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ - "siphasher", + "siphasher 1.0.1", ] [[package]] @@ -873,9 +874,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -927,7 +928,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -940,7 +941,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1049,9 +1050,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ "bitflags", "errno", @@ -1110,14 +1111,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -1174,6 +1175,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "smallvec" version = "1.13.2" @@ -1226,7 +1233,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1242,9 +1249,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.94" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -1326,7 +1333,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] [[package]] @@ -1604,9 +1611,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.22" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] @@ -1638,5 +1645,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn 2.0.96", ] diff --git a/flake.lock b/flake.lock index 7672c21..3e4af70 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1735834308, - "narHash": "sha256-dklw3AXr3OGO4/XT1Tu3Xz9n/we8GctZZ75ZWVqAVhk=", + "lastModified": 1736798957, + "narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "6df24922a1400241dae323af55f30e4318a6ca65", + "rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3", "type": "github" }, "original": { From c15062ab4c1ff391b327b91f580b541f37cdde56 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 15 Jan 2025 21:33:58 +0800 Subject: [PATCH 12/49] msys2: update --- nix/windows/msys2_packages.nix | 162 ++++++++++++++++----------------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/nix/windows/msys2_packages.nix b/nix/windows/msys2_packages.nix index 0ac1aa8..0859244 100644 --- a/nix/windows/msys2_packages.nix +++ b/nix/windows/msys2_packages.nix @@ -1,15 +1,15 @@ { pkgs } : [ (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0frb5k16bbxdf8g379d16vl3qrh7n9pydn83gpfxpvwf3qlvnzyl"; - name = "mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1gv6hbqvfgjzirpljql1shlchldmf5ww3rfsspg90pq1frnwavjl"; + name = "mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; - sha256 = "0wh5km0v8j50pqz9bxb4f0w7r8zhsvssrjvc94np53iq8wjagk86"; - name = "mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; + sha256 = "1wbkvrx14ahc04cgkydvlxwmsl8jfnqwhy9sy4kn4wkdzmlcp1ax"; + name = "mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -19,15 +19,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; - sha256 = "1g2bkhgf60dywccxw911ydyigf3m25yqfh81m5099swr7mjsmzyf"; - name = "mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; + sha256 = "0vn5xgx9jjg66f8r9ylm9220qdbjdkffykfl6nwj14zv9y7xh4nj"; + name = "mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; - sha256 = "0ll6ci6d3mc7g04q0xixjc209bh8r874dqbczgns69jsad3wg6mi"; - name = "mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; + sha256 = "0wbp5pmrr0rk4mx7d1frvqlk4a061zw31zscs57srmvl0wv3pi2a"; + name = "mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -55,69 +55,69 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1clrbm8dk893byj8s15pgcgqqijm2zkd10zgyakamd8m354kj9q4"; - name = "mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0fpsnfyf0bg39a4ygzga06sr4wv4jp1jnc8lk6sr3z0nim0nlhjn"; + name = "mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1iz2c9475h8p20ydpp0znbhyb62rlrk7wr7xl7cmwbam7wkwr8rn"; - name = "mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0whqs9nvfmgxj3c83px6dipcdw9zi858kgd8130201fy1mbnafp1"; + name = "mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1hidciwlakxrp4kyb0j2v6g4lv76nn834g6b88w1j94fk3qc765d"; - name = "mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0rmzri7h043i73jy3c2jcrg3hy40dr5s9n96kmxgaghfhvlpilps"; + name = "mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1m1yhjkgzlbk10sv966qk4yji009ga0lr25gpgj2w7mcd2wixcr3"; - name = "mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; + sha256 = "04cqlh35asvlh06nmhwnx9h0yrqk8zxd9lpzxmm1xh64kvm9maxn"; + name = "mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "08gxc7h2achckknn6fz3p6yi7gxxvbaday8fpm4j56c4sa04n0df"; - name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "05zsqgq8zwdcfacyqdxdjcf80447bgnrz71xv5cds0y135yziy7l"; + name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "0fxd1pb197ki0gzw6z8gmd6wgpd9d28js6cp5d31d55kw7d1vz13"; - name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "12fkxpk7rwy36snvvc7sdivx81pd4ckzh5ilyh7gl6ly4qayppp6"; + name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; - sha256 = "1a8pjyhrzpc2z3784xxwix4i7yrz03ygnsk1wv9k0yq8m8wi9nbw"; - name = "mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; + sha256 = "102bbv5acq1fvrfn8bp1x3503cb8hvcxmlpr86qsba4vm11l0wrw"; + name = "mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "140m312jx1sywqjkvfij69d268m4jpdmilq5bb8khkf0ayb16036"; - name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1sris0qczxk5px9xy85976hbmqrpg49ns7yyzd9p455ckf740cid"; + name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; - sha256 = "017j4h511wg37bacym73f8g6s0jcfgzbzabzxpc6anr3gy4kkpbg"; - name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; + sha256 = "1r0m5xpsxdl00a2daj4p0wgl6037700pvw6p6zl91h1dr092r6pa"; + name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; - sha256 = "11f4i4ai2bzvq6f06vxk1ymv7056c9707vdw489f1i2bdrf0c0ii"; - name = "mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0j4a642fpnvqs79chhinc8r5q53q1wllmc1bzb01a4y7w9rqg4hw"; + name = "mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; - sha256 = "0nxs571vb4f1i5vp91134p5blns9ml2r25nx6kdlg0zhd5x85kvm"; - name = "mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; + sha256 = "0nrz9788grl50nkbhxswry143rrwpdnc6pk6f0k30kcp19qq6y2d"; + name = "mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -127,9 +127,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; - sha256 = "1mpn397qsdz3l2fav6ymwjlj96ialn9m8sldii3ymbcyhranl3xx"; - name = "mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; + sha256 = "1dppwwx3wrn0lzrlk2q7bpsainbidrpw1ndp1aasyv42xhxl1sn1"; + name = "mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -139,9 +139,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; - sha256 = "13nz49li39z1zgfx1q9jg4vrmyrmqb6qdq0nqshidaqc6zr16k3g"; - name = "mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; + sha256 = "1zg58qbfybyqzcj0dalb13l48f9jsras318h02rka65r7wi0pdcg"; + name = "mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -169,9 +169,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; - sha256 = "1q5nxhsk04gidz66ai5wgd4dr04lfyakkfja9p0r5hrgg4ppqqjg"; - name = "mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; + sha256 = "0c36lg63imzw8i6j1ard42v5wgzpc83phzk8lvifvm0djndq2bbj"; + name = "mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -193,9 +193,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; - sha256 = "1p7q47fin12vzyf126v1azbbpgpa0y6ighfh6mbfdb6zcyq74kbd"; - name = "mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; + sha256 = "0kd2f7yh90815kyldxvdy8c6jyxyw0wv4f7k3shwp98w874m0mxd"; + name = "mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -271,15 +271,15 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; - sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc"; - name = "mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; + sha256 = "0gdn1351knjwgsqgyaa3l55qs135k7dn6mlf04vzjxlc1895wx5z"; + name = "mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; - sha256 = "139f91r392c68hsajm0c81690pmzkywb0p4x8ms8ms53ncxnz6gz"; - name = "mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; + sha256 = "1xjjwgkqf2j97pcx0yd6j0lgmzgbgqjjf0s7j29mc03g89fhdhw0"; + name = "mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -289,9 +289,9 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; - sha256 = "1hlfj9g4s767s502sawwbcv4a0xd3ym3ip4jswmhq48wh5050iyb"; - name = "mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; + sha256 = "0f98pzrwsxil90n55hz2ym2x2rzrrjrmnj8i2203n189qbxbg2c9"; + name = "mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; }) (pkgs.fetchurl { @@ -331,32 +331,32 @@ }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; - sha256 = "1v15j2pzy9wj4n1rjngdi2hf8h0l9z4lri3xb86yvdv1xl2msj6h"; - name = "mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; + sha256 = "0lksgrmylvpr7yyjcc1szm30pnag7ixrj7vhdql1ryi4k9309v8s"; + name = "mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; - sha256 = "1pn1fbj74rx837s9z8gqs4b0cr7kqi5m1m2mi9ibjpw64m1aqwxv"; - name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; + sha256 = "0d3mm26hnw716n0ppzqhydxcgm4im081hiiy6l4zp267ad3kfg93"; + name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; - sha256 = "18p1zhf7h3k3phf3bl483jg3k7y9zq375z6ww75g62158ic9lfyc"; - name = "mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; + sha256 = "006f2s12jmk35rppkp20rlm7k4kknsnh5h4krqs2ry2rd6qqkk9h"; + name = "mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; - sha256 = "1kiy7ail04ias47xbbhl9vpsz02g0g3f29ncgx5gcks9vgqldp6m"; - name = "mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; + sha256 = "0sgkhax9cwmkkrfrir45l91h6pgg339gaw6147gsayf8h8ag4brg"; + name = "mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; }) (pkgs.fetchurl { - url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; - sha256 = "03l04kjmy5p9whaw0h619gdg7yw1gxbz8phifq4pzh3c1wlw7yfd"; - name = "mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst"; + url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; + sha256 = "12ivpaj967y4bi8396q3fpii4fy5aakidxpv16rkyg1b831k0h93"; + name = "mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; }) ] From 4bd5349381d4b6ed4b3bfe430ffe13649f24050b Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 10:49:13 +0800 Subject: [PATCH 13/49] [core] add attributes to class string --- nac3core/src/toplevel/helper.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72d3eaa..eb72d37 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,21 +379,29 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { + TopLevelDef::Class { + name, ancestors, fields, methods, attributes, type_vars, .. + } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); + let attributes_str = attributes + .iter() + .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) + .collect_vec(); + let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( - "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", + "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), + attributes_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) From febfd1241d64486b6496b4a1c92ea18aa743e4d7 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 11:06:14 +0800 Subject: [PATCH 14/49] [core] add module type --- nac3artiq/src/lib.rs | 1 + nac3core/src/codegen/expr.rs | 2 +- nac3core/src/toplevel/composer.rs | 3 ++- nac3core/src/toplevel/helper.rs | 7 +++++++ nac3core/src/toplevel/mod.rs | 12 ++++++++++++ nac3core/src/typecheck/type_inferencer/mod.rs | 3 ++- 6 files changed, 25 insertions(+), 3 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index d35e66d..59d4dbe 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -713,6 +713,7 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } + TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), } } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 30b8dcd..6d2057e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -979,7 +979,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( TopLevelDef::Class { .. } => { return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) } - TopLevelDef::Variable { .. } => unreachable!(), + TopLevelDef::Variable { .. } | TopLevelDef::Module { .. } => unreachable!(), } } .or_else(|_: String| { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bd9a921..b293fb4 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,7 +101,8 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } + | TopLevelDef::Module { name, .. } => name.to_string(), TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index eb72d37..72502aa 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,6 +379,13 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { + TopLevelDef::Module { name, attributes, .. } => { + let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec(); + format!( + "Module {{\nname: {:?},\nattributes{:?}\n}}", + name, method_str + ) + } TopLevelDef::Class { name, ancestors, fields, methods, attributes, type_vars, .. } => { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index cba2f5e..88c007e 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -92,6 +92,18 @@ pub struct FunInstance { #[derive(Debug, Clone)] pub enum TopLevelDef { + Module { + /// Name of the module + name: StrRef, + /// Module ID used for [`TypeEnum`] + module_id: DefinitionId, + /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module + attributes: HashMap, + /// Symbol resolver of the module defined the class. + resolver: Option>, + /// Definition location. + loc: Option, + }, Class { /// Name for error messages and symbols. name: StrRef, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 742fa19..a58045b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2734,7 +2734,8 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } => (name, false), + TopLevelDef::Class { name, .. } + | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) From 7fac801936ff2cd5ef1ebdba73559eab2642a2cc Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 11:06:50 +0800 Subject: [PATCH 15/49] [artiq] add module primitive type --- nac3artiq/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 59d4dbe..4174fc8 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -159,6 +159,7 @@ pub struct PrimitivePythonId { generic_alias: (u64, u64), virtual_id: u64, option: u64, + module: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -1097,6 +1098,7 @@ impl Nac3 { tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), + module: get_attr_id(types_mod, "ModuleType"), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); From f15a64cc1b3e949dab9551698b4ed9d1c68d19e7 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 10 Jan 2025 12:05:11 +0800 Subject: [PATCH 16/49] [artiq] register modules --- nac3artiq/src/lib.rs | 25 +++++++++++++++++++++---- nac3core/src/toplevel/composer.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 4174fc8..78f427e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -43,7 +43,7 @@ use nac3core::{ OptimizationLevel, }, nac3parser::{ - ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, + ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }, symbol_resolver::SymbolResolver, @@ -470,12 +470,14 @@ impl Nac3 { ]; add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); + // Stores a mapping from module id to attributes let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; + let module_name: String = py_module.getattr("__name__")?.extract()?; let helper = helper.clone(); let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { @@ -490,7 +492,7 @@ impl Nac3 { } else { class_obj = None; } - let (name_to_pyid, resolver) = + let (name_to_pyid, resolver, _, _) = module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = @@ -519,9 +521,10 @@ impl Nac3 { }))) as Arc; let name_to_pyid = Rc::new(name_to_pyid); + let module_location = ast::Location::new(1, 1, stmt.location.file); module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone())); - (name_to_pyid, resolver) + .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); + (name_to_pyid, resolver, module_name, Some(module_location)) }); let (name, def_id, ty) = composer @@ -595,6 +598,20 @@ impl Nac3 { } } + // Adding top level module definitions + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { + let def_id= composer.register_top_level_module( + module_name, + module_name_to_pyid, + module_resolver, + module_location + ).map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; + + self.pyid_to_def.write().insert(module_id, def_id); + } + let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; let mut name_to_pyid: HashMap = HashMap::new(); let module = PyModule::new(py, "tmp")?; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index b293fb4..a4ca27f 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -202,6 +202,35 @@ impl TopLevelComposer { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } + /// register top level modules + pub fn register_top_level_module( + &mut self, + module_name: String, + name_to_pyid: Rc>, + resolver: Arc, + location: Option + ) -> Result { + let mut attributes: HashMap = HashMap::new(); + for (name, _) in name_to_pyid.iter() { + if let Ok(def_id) = resolver.get_identifier_def(*name) { + // Avoid repeated attribute instances resulting from multiple imports of same module + if self.defined_names.contains(&format!("{module_name}.{name}")) { + attributes.insert(*name, def_id); + } + }; + } + let module_def = TopLevelDef::Module { + name: module_name.clone().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + attributes, + resolver: Some(resolver), + loc: location + }; + + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); + Ok(DefinitionId(self.definition_ast_list.len() - 1)) + } + /// register, just remember the names of top level classes/function /// and check duplicate class/method/function definition pub fn register_top_level( From ce40a46f8a44f3d095e236e7773790e097b6c048 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 10:54:07 +0800 Subject: [PATCH 17/49] [core] add module type --- nac3core/src/codegen/concrete_type.rs | 13 ++ nac3core/src/typecheck/type_inferencer/mod.rs | 137 ++++++++++-------- nac3core/src/typecheck/typedef/mod.rs | 33 ++++- 3 files changed, 118 insertions(+), 65 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index d5c1fc3..f0c92ed 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -205,6 +205,19 @@ impl ConcreteTypeStore { }) .collect(), }, + TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule { + module_id: *module_id, + methods: attributes + .iter() + .filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) { + TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None, + _ => Some(( + *name, + (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1), + )), + }) + .collect(), + }, TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a58045b..7ce659f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2008,72 +2008,90 @@ impl Inferencer<'_> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { - // just a fast path - match (fields.get(&attr), ctx == ExprContext::Store) { - (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), - (Some((ty, false)), true) => report_type_error( - TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), - Some(value.location), - self.unifier, - ), - (None, mutable) => { - // Check whether it is a class attribute - let defs = self.top_level.definitions.read(); - let result = { - if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { - attributes.iter().find_map(|f| { - if f.0 == attr { - return Some(f.1); - } - None - }) - } else { - None - } - }; - match result { - Some(res) if !mutable => Ok(res), - Some(_) => report_error( - &format!("Class Attribute `{attr}` is immutable"), - value.location, - ), - None => report_type_error( - TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), - Some(value.location), - self.unifier, - ), - } - } - } - } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 - let result = { - self.top_level.definitions.read().iter().find_map(|def| { - if let Some(rear_guard) = def.try_read() { - if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { - if name.to_string() == self.unifier.stringify(sign.ret) { - return attributes.iter().find_map(|f| { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, fields, .. } => { + // just a fast path + match (fields.get(&attr), ctx == ExprContext::Store) { + (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, false)), true) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { if f.0 == attr { - return Some(f.clone().1); + return Some(f.1); } None - }); + }) + } else { + None } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), } } - None - }) - }; - match result { - Some(f) if ctx != ExprContext::Store => Ok(f), - Some(_) => { - report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) } - None => self.infer_general_attribute(value, attr, ctx), } - } else { - self.infer_general_attribute(value, attr, ctx) + TypeEnum::TFunc(sign) => { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => self.infer_general_attribute(value, attr, ctx), + } + } + TypeEnum::TModule { attributes, .. } => { + match (attributes.get(&attr), ctx == ExprContext::Load) { + (Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, true)), false) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, _) => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), + } + } + _ => self.infer_general_attribute(value, attr, ctx), } } @@ -2734,8 +2752,7 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } - | TopLevelDef::Module { name, .. } => (name, false), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e190c4c..f2f9ed6 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -270,6 +270,19 @@ pub enum TypeEnum { /// A function type. TFunc(FunSignature), + + /// Module Type + TModule { + /// The [`DefinitionId`] of this object type. + module_id: DefinitionId, + + /// The attributes present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). + attributes: Mapping, + }, } impl TypeEnum { @@ -284,6 +297,7 @@ impl TypeEnum { TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", + TypeEnum::TModule { .. } => "TModule", } } } @@ -593,7 +607,8 @@ impl Unifier { | TLiteral { .. } // functions are instantiated for each call sites, so the function type can contain // type variables. - | TFunc { .. } => true, + | TFunc { .. } + | TModule { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, @@ -1315,10 +1330,12 @@ impl Unifier { || format!("{id}"), |top_level| { let top_level_def = &top_level.definitions.read()[id]; - let TopLevelDef::Class { name, .. } = &*top_level_def.read() else { - unreachable!("expected class definition") + let top_level_def = top_level_def.read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + &*top_level_def + else { + unreachable!("expected module/class definition") }; - name.to_string() }, ) @@ -1446,6 +1463,10 @@ impl Unifier { let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); format!("fn[[{params}], {ret}]") } + TypeEnum::TModule { module_id, .. } => { + let name = obj_to_name(module_id.0); + name.to_string() + } } } @@ -1521,7 +1542,9 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => { + None + } TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TTuple { ty, is_vararg_ctx } => { let mut new_ty = Cow::from(ty); From 32f24261f280cfe36d75a08313697e540822655e Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 11:08:55 +0800 Subject: [PATCH 18/49] [artiq] add global variables to modules --- nac3artiq/src/lib.rs | 41 +++++--- nac3artiq/src/symbol_resolver.rs | 160 +++++++++++++++++++++++++++++- nac3core/src/toplevel/composer.rs | 56 +++++++---- nac3core/src/toplevel/helper.rs | 9 +- nac3core/src/toplevel/mod.rs | 6 +- 5 files changed, 234 insertions(+), 38 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 78f427e..ba6c4fa 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -277,6 +277,10 @@ impl Nac3 { } }) } + // Allow global variable declaration with `Kernel` type annotation + StmtKind::AnnAssign { ref annotation, .. } => { + matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) + } _ => false, }; @@ -522,8 +526,15 @@ impl Nac3 { as Arc; let name_to_pyid = Rc::new(name_to_pyid); let module_location = ast::Location::new(1, 1, stmt.location.file); - module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); + module_to_resolver_cache.insert( + module_id, + ( + name_to_pyid.clone(), + resolver.clone(), + module_name.clone(), + Some(module_location), + ), + ); (name_to_pyid, resolver, module_name, Some(module_location)) }); @@ -599,15 +610,19 @@ impl Nac3 { } // Adding top level module definitions - for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { - let def_id= composer.register_top_level_module( - module_name, - module_name_to_pyid, - module_resolver, - module_location - ).map_err(|e| { - CompileError::new_err(format!("compilation failed\n----------\n{e}")) - })?; + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in + module_to_resolver_cache + { + let def_id = composer + .register_top_level_module( + &module_name, + &module_name_to_pyid, + module_resolver, + module_location, + ) + .map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; self.pyid_to_def.write().insert(module_id, def_id); } @@ -731,7 +746,9 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } - TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), + TopLevelDef::Module { .. } => { + unreachable!("Type module cannot be decorated with @rpc") + } } } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d976866..4b398a9 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -23,7 +23,7 @@ use nac3core::{ inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -674,6 +674,48 @@ impl InnerResolver { }) }); + // check if obj is module + if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)? + == self.primitive_ids.module + && self.pyid_to_def.read().contains_key(&py_obj_id) + { + let def_id = self.pyid_to_def.read()[&py_obj_id]; + let def = defs[def_id.0].read(); + let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } = + &*def + else { + unreachable!("must be a module here"); + }; + // Construct the module return type + let mut module_attributes = HashMap::new(); + for (name, _) in attributes { + let attribute_obj = obj.getattr(name.to_string().as_str())?; + let attribute_ty = + self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?; + if let Ok(attribute_ty) = attribute_ty { + module_attributes.insert(*name, (attribute_ty, false)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + for name in methods.keys() { + let method_obj = obj.getattr(name.to_string().as_str())?; + let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?; + if let Ok(method_ty) = method_ty { + module_attributes.insert(*name, (method_ty, true)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + let module_ty = + TypeEnum::TModule { module_id: *module_id, attributes: module_attributes }; + + let ty = unifier.add_ty(module_ty); + return Ok(Ok(ty)); + } + if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); return Ok(Ok(ty)); @@ -1373,6 +1415,77 @@ impl InnerResolver { None => Ok(None), } } + } else if ty_id == self.primitive_ids.module { + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let fields = { + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read(); + let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() }; + attributes + .iter() + .filter_map(|f| { + let definition = top_level_defs.get(f.1 .0).unwrap().read(); + if let TopLevelDef::Variable { ty, .. } = &*definition { + Some((f.0, *ty)) + } else { + None + } + }) + .collect_vec() + }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty)| { + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } } else { let id_str = id.to_string(); @@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx>( &self, id: StrRef, - _: &mut CodeGenContext<'ctx, '_>, - _: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { + if let Some(def_id) = self.0.id_to_def.read().get(&id) { + let top_levels = ctx.top_level.definitions.read(); + if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) { + let module_val = &self.0.module; + let ret = Python::with_gil(|py| -> PyResult> { + let module_val = module_val.as_ref(py); + + let ty = self.0.get_obj_type( + py, + module_val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; + if let Err(ty) = ty { + return Ok(Err(ty)); + } + let ty = ty.unwrap(); + let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap(); + let (idx, _) = ctx.get_attr_index(ty, id); + let ret = unsafe { + ctx.builder.build_gep( + obj.into_pointer_value(), + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(idx as u64, false), + ], + id.to_string().as_str(), + ) + } + .unwrap(); + Ok(Ok(ret.as_basic_value_enum())) + }) + .unwrap(); + if ret.is_err() { + return None; + } + return Some(ret.unwrap().into()); + } + } + let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a4ca27f..a6a0ce7 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,8 +101,9 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } - | TopLevelDef::Module { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { + name.to_string() + } TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) @@ -205,29 +206,37 @@ impl TopLevelComposer { /// register top level modules pub fn register_top_level_module( &mut self, - module_name: String, - name_to_pyid: Rc>, + module_name: &str, + name_to_pyid: &Rc>, resolver: Arc, - location: Option + location: Option, ) -> Result { - let mut attributes: HashMap = HashMap::new(); + let mut methods: HashMap = HashMap::new(); + let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new(); + for (name, _) in name_to_pyid.iter() { if let Ok(def_id) = resolver.get_identifier_def(*name) { // Avoid repeated attribute instances resulting from multiple imports of same module if self.defined_names.contains(&format!("{module_name}.{name}")) { - attributes.insert(*name, def_id); + match &*self.definition_ast_list[def_id.0].0.read() { + TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => { + methods.insert(*name, def_id); + } + _ => attributes.push((*name, def_id)), + } } }; } - let module_def = TopLevelDef::Module { - name: module_name.clone().into(), - module_id: DefinitionId(self.definition_ast_list.len()), - attributes, - resolver: Some(resolver), - loc: location + let module_def = TopLevelDef::Module { + name: module_name.to_string().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + methods, + attributes, + resolver: Some(resolver), + loc: location, }; - self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None)); Ok(DefinitionId(self.definition_ast_list.len() - 1)) } @@ -499,10 +508,10 @@ impl TopLevelComposer { self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + self.analyze_top_level_variables()?; if inference { self.analyze_function_instance()?; } - self.analyze_top_level_variables()?; Ok(()) } @@ -1440,7 +1449,7 @@ impl TopLevelComposer { Ok(()) } - /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -1971,7 +1980,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 5. Analyze and populate the types of global variables. + /// Step 4. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -1989,6 +1998,19 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72502aa..4ca5464 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,11 +379,12 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Module { name, attributes, .. } => { - let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec(); + TopLevelDef::Module { name, attributes, methods, .. } => { format!( - "Module {{\nname: {:?},\nattributes{:?}\n}}", - name, method_str + "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}", + name, + attributes.iter().map(|(n, _)| n.to_string()).collect_vec(), + methods.iter().map(|(n, _)| n.to_string()).collect_vec() ) } TopLevelDef::Class { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 88c007e..3ffd568 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -97,8 +97,10 @@ pub enum TopLevelDef { name: StrRef, /// Module ID used for [`TypeEnum`] module_id: DefinitionId, - /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module - attributes: HashMap, + /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module + methods: HashMap, + /// `DefinitionId` of `TopLevelDef::{Variable}` within the module + attributes: Vec<(StrRef, DefinitionId)>, /// Symbol resolver of the module defined the class. resolver: Option>, /// Definition location. From 5fdbc34b430bd5875623eb5a0e0839d99422555d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 11:11:53 +0800 Subject: [PATCH 19/49] [core] implement codegen for modules --- nac3artiq/src/codegen.rs | 28 +++++++++++++++++++++++ nac3core/src/codegen/concrete_type.rs | 13 +++++++++++ nac3core/src/codegen/expr.rs | 26 +++++++++++++++++----- nac3core/src/codegen/mod.rs | 32 +++++++++++++++++++++++++++ nac3core/src/symbol_resolver.rs | 8 ++++--- 5 files changed, 98 insertions(+), 9 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index c968198..e472705 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>( )); } } + TypeEnum::TModule { attributes, .. } => { + let mut fields = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + for (name, (field_ty, is_method)) in attributes { + if *is_method { + continue; + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + fields.push(name.to_string()); + let (index, _) = ctx.get_attr_index(ty, *name); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); + } + } + if !fields.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", fields)?; + host_attributes.append(pydict)?; + } + } _ => {} } } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index f0c92ed..503a4ae 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum { fields: HashMap, params: IndexMap, }, + TModule { + module_id: DefinitionId, + methods: HashMap, + }, TVirtual { ty: ConcreteType, }, @@ -297,6 +301,15 @@ impl ConcreteTypeStore { TypeVar { id, ty } })), }, + ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule { + module_id: *module_id, + attributes: methods + .iter() + .map(|(name, cty)| { + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) + }) + .collect::>(), + }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args .iter() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 6d2057e..fd2cd28 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -61,8 +61,13 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.clone() + } else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) { + indexmap::IndexMap::new() + } else { + unreachable!() + } }) .unwrap_or_default(); vars.extend(fun_vars); @@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, + TypeEnum::TModule { module_id, .. } => *module_id, // we cannot have other types, virtual type should be handled by function calls _ => codegen_unreachable!(self), }; @@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> { let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); (attribute_index.0, Some(attribute_index.1 .2.clone())) } + } else if let TopLevelDef::Module { attributes, .. } = &*def.read() { + (attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None) } else { codegen_unreachable!(self) }; @@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id + } else if let TypeEnum::TModule { module_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *module_id } else { codegen_unreachable!(ctx) }; @@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { + if let TopLevelDef::Class { methods, .. } = &*obj_def { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } else if let TopLevelDef::Module { methods, .. } = &*obj_def { + *methods.iter().find(|method| method.0 == attr).unwrap().1 + } else { codegen_unreachable!(ctx) - }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index dcfa2b8..37e1bb3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { + TModule {module_id, attributes} => { + let top_level_defs = top_level.definitions.read(); + let definition = top_level_defs.get(module_id.0).unwrap(); + let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else { + unreachable!() + }; + let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name.to_string()); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into(), + ); + let module_fields: Vec> = attribute_fields.iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + attributes[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&module_fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty; + }, TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes if PrimDef::contains_id(*obj_id) { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2378dd6..4829093 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { - unreachable!("expected class definition") + let top_level_def = &*top_level_defs[id].read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + top_level_def + else { + unreachable!("expected class/module definition") }; - name.to_string() }, &mut |id| format!("typevar{id}"), From 14e80dfab7dbd0522f5761c16a3fb30650c43951 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 12:41:30 +0800 Subject: [PATCH 20/49] update snapshots --- ...c3core__toplevel__test__test_analyze__generic_class.snap | 4 ++-- ..._toplevel__test__test_analyze__inheritance_override.snap | 6 +++--- ...e__toplevel__test__test_analyze__list_tuple_generic.snap | 4 ++-- .../nac3core__toplevel__test__test_analyze__self1.snap | 4 ++-- ..._toplevel__test__test_analyze__simple_class_compose.snap | 6 +++--- ..._toplevel__test__test_analyze__simple_pass_in_class.snap | 4 +--- 6 files changed, 13 insertions(+), 15 deletions(-) diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 4332b47..8c827ee 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 60e0c19..b8a80a5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 4660181..05f4488 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", - "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index da58d12..7d3922e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 8f384fa..b55e998 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -3,14 +3,14 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap index 5178f1b..2f37789 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap @@ -1,9 +1,7 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 549 expression: res_vec - --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n", ] From 879b063968235b2c14950baed75d2b539311d235 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 12:42:13 +0800 Subject: [PATCH 21/49] [artiq] add tests for module support --- nac3artiq/demo/module_support.py | 29 +++++++++++++++++++ nac3artiq/demo/tests/global_variables.py | 14 +++++++++ .../{ => tests}/string_attribute_issue337.py | 12 ++------ .../support_class_attr_issue102.py | 3 -- 4 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 nac3artiq/demo/module_support.py create mode 100644 nac3artiq/demo/tests/global_variables.py rename nac3artiq/demo/{ => tests}/string_attribute_issue337.py (57%) rename nac3artiq/demo/{ => tests}/support_class_attr_issue102.py (99%) diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py new file mode 100644 index 0000000..a863b38 --- /dev/null +++ b/nac3artiq/demo/module_support.py @@ -0,0 +1,29 @@ +from min_artiq import * +import tests.string_attribute_issue337 as issue337 +import tests.support_class_attr_issue102 as issue102 +import tests.global_variables as global_variables + +@nac3 +class TestModuleSupport: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + # Accessing classes + issue337.Demo().run() + obj = issue102.Demo() + obj.attr3 = 3 + + # Calling functions + global_variables.inc_X() + global_variables.display_X() + + # Updating global variables + global_variables.X = 9 + global_variables.display_X() + +if __name__ == "__main__": + TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/tests/global_variables.py b/nac3artiq/demo/tests/global_variables.py new file mode 100644 index 0000000..ac0e0cf --- /dev/null +++ b/nac3artiq/demo/tests/global_variables.py @@ -0,0 +1,14 @@ +from min_artiq import * +from numpy import int32 + +X: Kernel[int32] = 1 + +@rpc +def display_X(): + print_int32(X) + +@kernel +def inc_X(): + global X + X += 1 + diff --git a/nac3artiq/demo/string_attribute_issue337.py b/nac3artiq/demo/tests/string_attribute_issue337.py similarity index 57% rename from nac3artiq/demo/string_attribute_issue337.py rename to nac3artiq/demo/tests/string_attribute_issue337.py index 9749462..c0b36ed 100644 --- a/nac3artiq/demo/string_attribute_issue337.py +++ b/nac3artiq/demo/tests/string_attribute_issue337.py @@ -1,16 +1,13 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: - core: KernelInvariant[Core] - attr1: KernelInvariant[str] - attr2: KernelInvariant[int32] - + attr1: Kernel[str] + attr2: Kernel[int32] + @kernel def __init__(self): - self.core = Core() self.attr2 = 32 self.attr1 = "SAMPLE" @@ -19,6 +16,3 @@ class Demo: print_int32(self.attr2) self.attr1 - -if __name__ == "__main__": - Demo().run() diff --git a/nac3artiq/demo/support_class_attr_issue102.py b/nac3artiq/demo/tests/support_class_attr_issue102.py similarity index 99% rename from nac3artiq/demo/support_class_attr_issue102.py rename to nac3artiq/demo/tests/support_class_attr_issue102.py index 1b93144..0482e3f 100644 --- a/nac3artiq/demo/support_class_attr_issue102.py +++ b/nac3artiq/demo/tests/support_class_attr_issue102.py @@ -1,7 +1,6 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: attr1: KernelInvariant[int32] = 2 @@ -12,7 +11,6 @@ class Demo: def __init__(self): self.attr3 = 8 - @nac3 class NAC3Devices: core: KernelInvariant[Core] @@ -35,6 +33,5 @@ class NAC3Devices: NAC3Devices.attr4 # Attributes accessible for classes without __init__ - if __name__ == "__main__": NAC3Devices().run() From 2783834cb1e2df6e41ed2689cbd1a2db63ee18d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Fri, 17 Jan 2025 12:45:51 +0800 Subject: [PATCH 22/49] nac3artiq/demo: merge EmbeddingMap into min_artiq --- nac3artiq/demo/embedding_map.py | 39 ------------------------------- nac3artiq/demo/min_artiq.py | 41 ++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 40 deletions(-) delete mode 100644 nac3artiq/demo/embedding_map.py diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py deleted file mode 100644 index a43af69..0000000 --- a/nac3artiq/demo/embedding_map.py +++ /dev/null @@ -1,39 +0,0 @@ -class EmbeddingMap: - def __init__(self): - self.object_inverse_map = {} - self.object_map = {} - self.string_map = {} - self.string_reverse_map = {} - self.function_map = {} - self.attributes_writeback = [] - - def store_function(self, key, fun): - self.function_map[key] = fun - return key - - def store_object(self, obj): - obj_id = id(obj) - if obj_id in self.object_inverse_map: - return self.object_inverse_map[obj_id] - key = len(self.object_map) + 1 - self.object_map[key] = obj - self.object_inverse_map[obj_id] = key - return key - - def store_str(self, s): - if s in self.string_reverse_map: - return self.string_reverse_map[s] - key = len(self.string_map) - self.string_map[key] = s - self.string_reverse_map[s] = key - return key - - def retrieve_function(self, key): - return self.function_map[key] - - def retrieve_object(self, key): - return self.object_map[key] - - def retrieve_str(self, key): - return self.string_map[key] - diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 62d32cc..fef018b 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -6,7 +6,6 @@ from typing import Generic, TypeVar from math import floor, ceil import nac3artiq -from embedding_map import EmbeddingMap __all__ = [ @@ -193,6 +192,46 @@ def print_int64(x: int64): raise NotImplementedError("syscall not simulated") +class EmbeddingMap: + def __init__(self): + self.object_inverse_map = {} + self.object_map = {} + self.string_map = {} + self.string_reverse_map = {} + self.function_map = {} + self.attributes_writeback = [] + + def store_function(self, key, fun): + self.function_map[key] = fun + return key + + def store_object(self, obj): + obj_id = id(obj) + if obj_id in self.object_inverse_map: + return self.object_inverse_map[obj_id] + key = len(self.object_map) + 1 + self.object_map[key] = obj + self.object_inverse_map[obj_id] = key + return key + + def store_str(self, s): + if s in self.string_reverse_map: + return self.string_reverse_map[s] + key = len(self.string_map) + self.string_map[key] = s + self.string_reverse_map[s] = key + return key + + def retrieve_function(self, key): + return self.function_map[key] + + def retrieve_object(self, key): + return self.object_map[key] + + def retrieve_str(self, key): + return self.string_map[key] + + @nac3 class Core: ref_period: KernelInvariant[float] From 2d275949b8fc274afe7d441f63b90db3556d1bfd Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 17 Jan 2025 13:04:16 +0800 Subject: [PATCH 23/49] move tests from artiq to standalone --- .../demo/tests/string_attribute_issue337.py | 18 --------- .../demo/tests/support_class_attr_issue102.py | 37 ------------------- nac3standalone/demo/src/class_attributes.py | 35 ++++++++++++++++++ 3 files changed, 35 insertions(+), 55 deletions(-) delete mode 100644 nac3artiq/demo/tests/string_attribute_issue337.py delete mode 100644 nac3artiq/demo/tests/support_class_attr_issue102.py create mode 100644 nac3standalone/demo/src/class_attributes.py diff --git a/nac3artiq/demo/tests/string_attribute_issue337.py b/nac3artiq/demo/tests/string_attribute_issue337.py deleted file mode 100644 index c0b36ed..0000000 --- a/nac3artiq/demo/tests/string_attribute_issue337.py +++ /dev/null @@ -1,18 +0,0 @@ -from min_artiq import * -from numpy import int32 - -@nac3 -class Demo: - attr1: Kernel[str] - attr2: Kernel[int32] - - @kernel - def __init__(self): - self.attr2 = 32 - self.attr1 = "SAMPLE" - - @kernel - def run(self): - print_int32(self.attr2) - self.attr1 - diff --git a/nac3artiq/demo/tests/support_class_attr_issue102.py b/nac3artiq/demo/tests/support_class_attr_issue102.py deleted file mode 100644 index 0482e3f..0000000 --- a/nac3artiq/demo/tests/support_class_attr_issue102.py +++ /dev/null @@ -1,37 +0,0 @@ -from min_artiq import * -from numpy import int32 - -@nac3 -class Demo: - attr1: KernelInvariant[int32] = 2 - attr2: int32 = 4 - attr3: Kernel[int32] - - @kernel - def __init__(self): - self.attr3 = 8 - -@nac3 -class NAC3Devices: - core: KernelInvariant[Core] - attr4: KernelInvariant[int32] = 16 - - def __init__(self): - self.core = Core() - - @kernel - def run(self): - Demo.attr1 # Supported - # Demo.attr2 # Field not accessible on Kernel - # Demo.attr3 # Only attributes can be accessed in this way - # Demo.attr1 = 2 # Attributes are immutable - - self.attr4 # Attributes can be accessed within class - - obj = Demo() - obj.attr1 # Attributes can be accessed by class objects - - NAC3Devices.attr4 # Attributes accessible for classes without __init__ - -if __name__ == "__main__": - NAC3Devices().run() diff --git a/nac3standalone/demo/src/class_attributes.py b/nac3standalone/demo/src/class_attributes.py new file mode 100644 index 0000000..b58958f --- /dev/null +++ b/nac3standalone/demo/src/class_attributes.py @@ -0,0 +1,35 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_strln(x: str): + ... + + +class A: + a: int32 = 1 + b: int32 + c: str = "test" + d: str + + def __init__(self): + self.b = 2 + self.d = "test" + + output_int32(self.a) # Attributes can be accessed within class + + +def run() -> int32: + output_int32(A.a) # Attributes can be directly accessed with class name + # A.b # Only attributes can be accessed in this way + # A.a = 2 # Attributes are immutable + + obj = A() + output_int32(obj.a) # Attributes can be accessed by class objects + + output_strln(obj.c) + output_strln(obj.d) + + return 0 + From f817d3347bb4b385754141a50029ebf8e03a9912 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 17 Jan 2025 17:56:43 +0800 Subject: [PATCH 24/49] [artiq] cleanup module functionality tests --- nac3artiq/demo/module.py | 26 ++++++++++++++++++++++++ nac3artiq/demo/module_support.py | 17 +++++++--------- nac3artiq/demo/tests/global_variables.py | 14 ------------- 3 files changed, 33 insertions(+), 24 deletions(-) create mode 100644 nac3artiq/demo/module.py delete mode 100644 nac3artiq/demo/tests/global_variables.py diff --git a/nac3artiq/demo/module.py b/nac3artiq/demo/module.py new file mode 100644 index 0000000..58f9245 --- /dev/null +++ b/nac3artiq/demo/module.py @@ -0,0 +1,26 @@ +from min_artiq import * +from numpy import int32 + +# Global Variable Definition +X: Kernel[int32] = 1 + +# TopLevelFunction Defintion +@kernel +def display_X(): + print_int32(X) + +# TopLevel Class Definition +@nac3 +class A: + @kernel + def __init__(self): + self.set_x(1) + + @kernel + def set_x(self, new_val: int32): + global X + X = new_val + + @kernel + def get_X(self) -> int32: + return X diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py index a863b38..78ef656 100644 --- a/nac3artiq/demo/module_support.py +++ b/nac3artiq/demo/module_support.py @@ -1,7 +1,5 @@ from min_artiq import * -import tests.string_attribute_issue337 as issue337 -import tests.support_class_attr_issue102 as issue102 -import tests.global_variables as global_variables +import module as module_definition @nac3 class TestModuleSupport: @@ -13,17 +11,16 @@ class TestModuleSupport: @kernel def run(self): # Accessing classes - issue337.Demo().run() - obj = issue102.Demo() - obj.attr3 = 3 + obj = module_definition.A() + obj.get_X() + obj.set_x(2) # Calling functions - global_variables.inc_X() - global_variables.display_X() + module_definition.display_X() # Updating global variables - global_variables.X = 9 - global_variables.display_X() + module_definition.X = 9 + module_definition.display_X() if __name__ == "__main__": TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/tests/global_variables.py b/nac3artiq/demo/tests/global_variables.py deleted file mode 100644 index ac0e0cf..0000000 --- a/nac3artiq/demo/tests/global_variables.py +++ /dev/null @@ -1,14 +0,0 @@ -from min_artiq import * -from numpy import int32 - -X: Kernel[int32] = 1 - -@rpc -def display_X(): - print_int32(X) - -@kernel -def inc_X(): - global X - X += 1 - From 05fd1a519902b3a8c050ec031ca7496b32626339 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Jan 2025 22:09:03 +0800 Subject: [PATCH 25/49] [meta] Use lld as linker --- .cargo/config.toml | 2 ++ flake.nix | 5 +++-- nac3standalone/demo/run_demo.sh | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..188308d --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "link-arg=-fuse-ld=lld"] diff --git a/flake.nix b/flake.nix index a20e1f1..a48ff69 100644 --- a/flake.nix +++ b/flake.nix @@ -41,7 +41,7 @@ lockFile = ./Cargo.lock; }; passthru.cargoLock = cargoLock; - nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; + nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.bintools llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkPhase = @@ -120,6 +120,7 @@ buildInputs = [ (python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ])) pkgs.llvmPackages_14.llvm.out + pkgs.llvmPackages_14.bintools ]; phases = [ "buildPhase" "installPhase" ]; buildPhase = @@ -168,7 +169,7 @@ buildInputs = with pkgs; [ # build dependencies packages.x86_64-linux.llvm-nac3 - (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos + (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out llvmPackages_14.bintools # for running nac3standalone demos packages.x86_64-linux.llvm-tools-irrt cargo rustc diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index bec2eb6..78e32dd 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -58,7 +58,7 @@ rm -f ./*.o ./*.bc demo if [ -z "$i686" ]; then $nac3standalone "${nac3args[@]}" clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c - clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch + clang -o demo module.o demo.o $DEMO_LINALG_STUB -fuse-ld=lld -lm else $nac3standalone --triple i686-unknown-linux-gnu --target-features +sse2 "${nac3args[@]}" clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c From 37df08b803c4c58b1d95eb735d9d1d4b97910e10 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 11:01:30 +0800 Subject: [PATCH 26/49] [meta] Update dependencies --- Cargo.lock | 168 ++++++++++++++++++++++--------------------- nac3artiq/Cargo.toml | 4 +- nac3ast/Cargo.toml | 2 +- nac3core/Cargo.toml | 7 +- 4 files changed, 94 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1d9352..9435fed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -9,7 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -127,9 +127,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.9" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "shlex", ] @@ -142,9 +142,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -200,9 +200,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -334,9 +334,15 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "fxhash" @@ -374,7 +380,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] @@ -389,20 +407,14 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "foldhash", +] [[package]] name = "heck" @@ -437,9 +449,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -498,9 +510,9 @@ checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] @@ -522,9 +534,9 @@ dependencies = [ [[package]] name = "lalrpop" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06093b57658c723a21da679530e061a8c25340fa5a6f98e313b542268c7e2a1f" +checksum = "7047a26de42016abf8f181b46b398aef0b77ad46711df41847f6ed869a2a1d5b" dependencies = [ "ascii-canvas", "bit-set", @@ -544,9 +556,9 @@ dependencies = [ [[package]] name = "lalrpop-util" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feee752d43abd0f4807a921958ab4131f692a44d4d599733d4419c5d586176ce" +checksum = "e8d05b3fe34b8bd562c338db725dfa9beb9451a48f65f129ccb9538b48d2c93b" dependencies = [ "regex-automata", "rustversion", @@ -656,7 +668,7 @@ name = "nac3core" version = "0.1.0" dependencies = [ "crossbeam", - "indexmap 2.7.0", + "indexmap 2.7.1", "indoc", "inkwell", "insta", @@ -664,7 +676,6 @@ dependencies = [ "nac3core_derive", "nac3parser", "parking_lot", - "rayon", "regex", "strum", "strum_macros", @@ -752,12 +763,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap 2.7.1", ] [[package]] @@ -980,27 +991,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", + "getrandom 0.2.15", ] [[package]] @@ -1050,9 +1041,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -1069,9 +1060,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "same-file" @@ -1090,9 +1081,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" @@ -1116,9 +1107,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" dependencies = [ "itoa", "memchr", @@ -1165,9 +1156,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "similar" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" @@ -1189,12 +1180,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "string-interner" -version = "0.17.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e" +checksum = "1a3275464d7a9f2d4cac57c89c2ef96a8524dba2864c8d6f82e3980baf136f9b" dependencies = [ - "cfg-if", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "serde", ] @@ -1272,13 +1262,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1363,7 +1353,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.7.1", "serde", "serde_spanned", "toml_datetime", @@ -1372,9 +1362,9 @@ dependencies = [ [[package]] name = "trybuild" -version = "1.0.101" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4" +checksum = "b812699e0c4f813b872b373a4471717d9eb550da14b311058a4d9cf4173cbca6" dependencies = [ "dissimilar", "glob", @@ -1446,9 +1436,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-width" @@ -1518,6 +1508,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "winapi-util" version = "0.1.9" @@ -1611,13 +1610,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "ad699df48212c6cc6eb4435f35500ac6fd3b9913324f938aea302022ce19d310" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/nac3artiq/Cargo.toml b/nac3artiq/Cargo.toml index 2da812e..fa80465 100644 --- a/nac3artiq/Cargo.toml +++ b/nac3artiq/Cargo.toml @@ -9,10 +9,10 @@ name = "nac3artiq" crate-type = ["cdylib"] [dependencies] -itertools = "0.13" +itertools = "0.14" pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } parking_lot = "0.12" -tempfile = "3.13" +tempfile = "3.16" nac3core = { path = "../nac3core" } nac3ld = { path = "../nac3ld" } diff --git a/nac3ast/Cargo.toml b/nac3ast/Cargo.toml index dc2bd55..947be09 100644 --- a/nac3ast/Cargo.toml +++ b/nac3ast/Cargo.toml @@ -11,5 +11,5 @@ fold = [] [dependencies] parking_lot = "0.12" -string-interner = "0.17" +string-interner = "0.18" fxhash = "0.2" diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 6521a33..7badcee 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -10,11 +10,10 @@ derive = ["dep:nac3core_derive"] no-escape-analysis = [] [dependencies] -itertools = "0.13" +itertools = "0.14" crossbeam = "0.8" -indexmap = "2.6" +indexmap = "2.7" parking_lot = "0.12" -rayon = "1.10" nac3core_derive = { path = "nac3core_derive", optional = true } nac3parser = { path = "../nac3parser" } strum = "0.26" @@ -31,4 +30,4 @@ indoc = "2.0" insta = "=1.11.0" [build-dependencies] -regex = "1.10" +regex = "1.11" From bdeeced1223a29db30bd5a92f48e627e1c953b16 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:55:49 +0800 Subject: [PATCH 27/49] [core] codegen: Normalize RangeType factory functions Better matches factory functions of other ProxyTypes. --- nac3core/src/codegen/mod.rs | 2 +- nac3core/src/codegen/test.rs | 3 ++- nac3core/src/codegen/types/range.rs | 38 +++++++++++++++++++++++------ 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 37e1bb3..73a28b7 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -800,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new(context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_base_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index a58a984..01672c5 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -453,8 +453,9 @@ fn test_classes_list_type_new() { #[test] fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); + let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); - let llvm_range = RangeType::new(&ctx); + let llvm_range = RangeType::new_with_generator(&generator, &ctx); assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); } diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index bdd4e79..b92d765 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -5,9 +5,12 @@ use inkwell::{ }; use super::ProxyType; -use crate::codegen::{ - values::{ProxyValue, RangeValue}, - {CodeGenContext, CodeGenerator}, +use crate::{ + codegen::{ + values::{ProxyValue, RangeValue}, + {CodeGenContext, CodeGenerator}, + }, + typecheck::typedef::{Type, TypeEnum}, }; /// Proxy type for a `range` type in LLVM. @@ -54,12 +57,33 @@ impl<'ctx> RangeType<'ctx> { llvm_i32.array_type(3).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new(ctx: &'ctx Context) -> Self { + fn new_impl(ctx: &'ctx Context) -> Self { let llvm_range = Self::llvm_type(ctx); - RangeType::from_type(llvm_range) + RangeType { ty: llvm_range } + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + RangeType::new_impl(ctx.ctx) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new_with_generator(_: &G, ctx: &'ctx Context) -> Self { + Self::new_impl(ctx) + } + + /// Creates an [`RangeType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new(ctx) } /// Creates an [`RangeType`] from a [`PointerType`]. From 87a637b448db0613c8cc8b9a3a1d9cb91e8fa1f4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:45:45 +0800 Subject: [PATCH 28/49] [core] codegen: Refactor Proxy{Type,Value} for StructProxy{Type,Value} --- nac3artiq/src/codegen.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 6 +- nac3core/src/codegen/expr.rs | 4 +- .../src/codegen/irrt/ndarray/broadcast.rs | 7 +- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/codegen/test.rs | 4 +- nac3core/src/codegen/types/list.rs | 65 ++--- nac3core/src/codegen/types/mod.rs | 18 +- .../src/codegen/types/ndarray/broadcast.rs | 53 ++-- .../src/codegen/types/ndarray/contiguous.rs | 69 +++--- .../src/codegen/types/ndarray/indexing.rs | 47 ++-- nac3core/src/codegen/types/ndarray/mod.rs | 49 ++-- nac3core/src/codegen/types/ndarray/nditer.rs | 49 ++-- nac3core/src/codegen/types/range.rs | 233 +++++++++--------- nac3core/src/codegen/types/tuple.rs | 22 +- nac3core/src/codegen/types/utils/slice.rs | 97 ++++---- nac3core/src/codegen/values/list.rs | 11 +- nac3core/src/codegen/values/mod.rs | 20 +- .../src/codegen/values/ndarray/broadcast.rs | 11 +- .../src/codegen/values/ndarray/contiguous.rs | 11 +- .../src/codegen/values/ndarray/indexing.rs | 11 +- nac3core/src/codegen/values/ndarray/mod.rs | 11 +- nac3core/src/codegen/values/ndarray/nditer.rs | 11 +- nac3core/src/codegen/values/range.rs | 23 +- nac3core/src/codegen/values/tuple.rs | 11 +- nac3core/src/codegen/values/utils/slice.rs | 11 +- nac3core/src/toplevel/builtins.rs | 6 +- 27 files changed, 342 insertions(+), 530 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index e472705..4c86028 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -19,9 +19,9 @@ use nac3core::{ llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, type_aligned_alloca, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, RangeValue, + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -1431,7 +1431,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeValue::from_pointer_value(value.into_pointer_value(), None); + let val = RangeType::new(ctx).map_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 911e3dc..6cacac4 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -11,10 +11,10 @@ use super::{ irrt::calculate_len_for_slice_range, llvm_intrinsics, macros::codegen_unreachable, - types::{ndarray::NDArrayType, ListType, TupleType}, + types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, values::{ ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray}, - ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, }; @@ -47,7 +47,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let range_ty = ctx.primitives.range; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range")); + let arg = RangeType::new(ctx).map_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index fd2cd28..4da0ef3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType}, + types::{ndarray::NDArrayType, ListType, RangeType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -1151,7 +1151,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index fceba25..a7d40a5 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -55,10 +55,9 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>( let llvm_usize = ctx.get_size_type(); assert_eq!(num_shape_entries.get_type(), llvm_usize); - assert!(ShapeEntryType::is_type( - generator, - ctx.ctx, - shape_entries.base_ptr(ctx, generator).get_type() + assert!(ShapeEntryType::is_representable( + shape_entries.base_ptr(ctx, generator).get_type(), + llvm_usize, ) .is_ok()); assert_eq!(dst_ndims.get_type(), llvm_usize); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 7b99bc2..e8f1d90 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -17,10 +17,10 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ ndarray::{RustNDIndex, ScalarOrNDArray}, - ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue, + ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, }, CodeGenContext, CodeGenerator, }; @@ -511,7 +511,7 @@ pub fn gen_for( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 01672c5..ecc0ba9 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -455,8 +455,10 @@ fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); + let llvm_usize = generator.get_size_type(&ctx); + let llvm_range = RangeType::new_with_generator(&generator, &ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_base_type(), llvm_usize).is_ok()); } #[test] diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 637cced..60015b8 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -56,34 +56,6 @@ impl<'ctx> ListStructFields<'ctx> { } impl<'ctx> ListType<'ctx> { - /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); - }; - - let fields = ListStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields( - fields, - llvm_ty, - "list", - &[(fields.items.name(), &|ty| { - if ty.is_pointer_type() { - Ok(()) - } else { - Err(format!("Expected T* for `list.items`, got {ty}")) - } - })], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> { @@ -184,7 +156,7 @@ impl<'ctx> ListType<'ctx> { /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); let ctx = ptr_ty.get_context(); @@ -336,24 +308,39 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); + }; + + let fields = ListStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "list", + &[(fields.items.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `list.items`, got {ty}")) + } + })], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 0a31d6a..5865d63 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -17,8 +17,7 @@ //! on the stack. use inkwell::{ - context::Context, - types::BasicType, + types::{BasicType, IntType}, values::{IntValue, PointerValue}, }; @@ -46,18 +45,15 @@ pub trait ProxyType<'ctx>: Into { /// The type of values represented by this type. type Value: ProxyValue<'ctx, Type = Self>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String>; - /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String>; + /// Checks whether the type represented by `ty` expresses the same type represented by this + /// [`ProxyType`]. + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String>; /// Returns the type that should be used in `alloca` IR statements. fn alloca_type(&self) -> impl BasicType<'ctx>; diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 3a1fd8d..af1a26f 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -32,28 +32,6 @@ pub struct ShapeEntryStructFields<'ctx> { } impl<'ctx> ShapeEntryType<'ctx> { - /// Checks whether `llvm_ty` represents a [`ShapeEntryType`], returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!( - "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" - )); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDArray", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -103,7 +81,7 @@ impl<'ctx> ShapeEntryType<'ctx> { /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -152,24 +130,33 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { type Base = PointerType<'ctx>; type Value = ShapeEntryValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!( + "Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}" + )); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index c751d57..1987ab6 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -58,36 +58,6 @@ impl<'ctx> ContiguousNDArrayStructFields<'ctx> { } impl<'ctx> ContiguousNDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields( - fields, - llvm_ty, - "ContiguousNDArray", - &[(fields.data.name(), &|ty| { - if ty.is_pointer_type() { - Ok(()) - } else { - Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) - } - })], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -160,7 +130,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, item, llvm_usize } } @@ -222,24 +192,41 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "ContiguousNDArray", + &[(fields.data.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) + } + })], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 3e4e136..8e15c90 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -35,25 +35,6 @@ pub struct NDIndexStructFields<'ctx> { } impl<'ctx> NDIndexType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndindex` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!( - "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" - )); - }; - - let fields = NDIndexStructFields::new(ctx, llvm_usize); - - check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) - } - #[must_use] fn fields( ctx: impl AsContextRef<'ctx>, @@ -96,7 +77,7 @@ impl<'ctx> NDIndexType<'ctx> { #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -180,24 +161,30 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = NDIndexStructFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[]) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index fe73307..1743fe2 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -62,26 +62,6 @@ pub struct NDArrayStructFields<'ctx> { } impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDArray", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields( @@ -211,7 +191,7 @@ impl<'ctx> NDArrayType<'ctx> { ndims: u64, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize } } @@ -450,24 +430,31 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ndarray_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDArray", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 45b6bb0..6246eef 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -44,26 +44,6 @@ pub struct NDIterStructFields<'ctx> { } impl<'ctx> NDIterType<'ctx> { - /// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { - return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); - }; - - check_struct_type_matches_fields( - Self::fields(ctx, llvm_usize), - llvm_ndarray_ty, - "NDIter", - &[], - ) - } - /// Returns an instance of [`StructFields`] containing all field accessors for this type. #[must_use] fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> { @@ -110,7 +90,7 @@ impl<'ctx> NDIterType<'ctx> { /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } } @@ -208,24 +188,31 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else { + return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields( + Self::fields(ctx, llvm_usize), + llvm_ndarray_ty, + "NDIter", + &[], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index b92d765..158152b 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -17,12 +17,125 @@ use crate::{ #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct RangeType<'ctx> { ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, } impl<'ctx> RangeType<'ctx> { - /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. - pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> { - let llvm_range_ty = llvm_ty.get_element_type(); + /// Creates an LLVM type corresponding to the expected structure of a `Range`. + #[must_use] + fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { + // typedef int32_t Range[3]; + let llvm_i32 = ctx.i32_type(); + llvm_i32.array_type(3).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_range = Self::llvm_type(ctx); + + RangeType { ty: llvm_range, llvm_usize } + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`RangeType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new(ctx) + } + + /// Creates an [`RangeType`] from a [`PointerType`]. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + RangeType { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of all fields of this `range` type. + #[must_use] + pub fn value_type(&self) -> IntType<'ctx> { + self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type Base = PointerType<'ctx>; + type Value = RangeValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + let llvm_range_ty = ty.get_element_type(); let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); }; @@ -49,120 +162,6 @@ impl<'ctx> RangeType<'ctx> { Ok(()) } - /// Creates an LLVM type corresponding to the expected structure of a `Range`. - #[must_use] - fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { - // typedef int32_t Range[3]; - let llvm_i32 = ctx.i32_type(); - llvm_i32.array_type(3).ptr_type(AddressSpace::default()) - } - - fn new_impl(ctx: &'ctx Context) -> Self { - let llvm_range = Self::llvm_type(ctx); - - RangeType { ty: llvm_range } - } - - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { - RangeType::new_impl(ctx.ctx) - } - - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new_with_generator(_: &G, ctx: &'ctx Context) -> Self { - Self::new_impl(ctx) - } - - /// Creates an [`RangeType`] from a [unifier type][Type]. - #[must_use] - pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { - // Check unifier type - assert!( - matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap()) - ); - - Self::new(ctx) - } - - /// Creates an [`RangeType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { - debug_assert!(Self::is_representable(ptr_ty).is_ok()); - - RangeType { ty: ptr_ty } - } - - /// Returns the type of all fields of this `range` type. - #[must_use] - pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() - } - - /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. - /// - /// See [`ProxyType::raw_alloca`]. - #[must_use] - pub fn alloca( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value(self.raw_alloca(ctx, name), name) - } - - /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. - /// - /// See [`ProxyType::raw_alloca_var`]. - #[must_use] - pub fn alloca_var( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value( - self.raw_alloca_var(generator, ctx, name), - name, - ) - } - - /// Converts an existing value into a [`RangeValue`]. - #[must_use] - pub fn map_value( - &self, - value: <>::Value as ProxyValue<'ctx>>::Base, - name: Option<&'ctx str>, - ) -> >::Value { - >::Value::from_pointer_value(value, name) - } -} - -impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { - type Base = PointerType<'ctx>; - type Value = RangeValue<'ctx>; - - fn is_type( - generator: &G, - ctx: &'ctx Context, - llvm_ty: impl BasicType<'ctx>, - ) -> Result<(), String> { - if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) - } else { - Err(format!("Expected pointer type, got {llvm_ty:?}")) - } - } - - fn is_representable( - _: &G, - _: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty) - } - fn alloca_type(&self) -> impl BasicType<'ctx> { self.as_base_type().get_element_type().into_struct_type() } diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 5c73652..d05b7f2 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -21,11 +21,6 @@ pub struct TupleType<'ctx> { } impl<'ctx> TupleType<'ctx> { - /// Checks whether `llvm_ty` represents any tuple type, returning [Err] if it does not. - pub fn is_representable(_value: StructType<'ctx>) -> Result<(), String> { - Ok(()) - } - /// Creates an LLVM type corresponding to the expected structure of a tuple. #[must_use] fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> { @@ -83,7 +78,7 @@ impl<'ctx> TupleType<'ctx> { /// Creates an [`TupleType`] from a [`StructType`]. #[must_use] pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_representable(struct_ty).is_ok()); + debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok()); TupleType { ty: struct_ty, llvm_usize } } @@ -165,24 +160,19 @@ impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { type Base = StructType<'ctx>; type Value = TupleValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected struct type, got {llvm_ty:?}")) } } - fn is_representable( - _generator: &G, - _ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty) + fn has_same_repr(_: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + Ok(()) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index 0ef4d1b..b7fafef 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -61,50 +61,6 @@ impl<'ctx> SliceFields<'ctx> { } impl<'ctx> SliceType<'ctx> { - /// Checks whether `llvm_ty` represents a `slice` type, returning [Err] if it does not. - pub fn is_representable( - llvm_ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - let ctx = llvm_ty.get_context(); - - let fields = SliceFields::new(ctx, llvm_usize); - - let llvm_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { - return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}")); - }; - - check_struct_type_matches_fields( - fields, - llvm_ty, - "Slice", - &[ - (fields.start.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.start`, got {ty}")) - } - }), - (fields.stop.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.stop`, got {ty}")) - } - }), - (fields.step.name(), &|ty| { - if ty.is_int_type() { - Ok(()) - } else { - Err(format!("Expected int type for `Slice.step`, got {ty}")) - } - }), - ], - ) - } - // TODO: Move this into e.g. StructProxyType #[must_use] pub fn get_fields(&self) -> SliceFields<'ctx> { @@ -156,7 +112,7 @@ impl<'ctx> SliceType<'ctx> { int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>, ) -> Self { - debug_assert!(Self::is_representable(ptr_ty, int_ty).is_ok()); + debug_assert!(Self::has_same_repr(ptr_ty, int_ty).is_ok()); Self { ty: ptr_ty, int_ty, llvm_usize } } @@ -221,24 +177,55 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { type Base = PointerType<'ctx>; type Value = SliceValue<'ctx>; - fn is_type( - generator: &G, - ctx: &'ctx Context, + fn is_representable( llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, ) -> Result<(), String> { if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { - >::is_representable(generator, ctx, ty) + Self::has_same_repr(ty, llvm_usize) } else { Err(format!("Expected pointer type, got {llvm_ty:?}")) } } - fn is_representable( - generator: &G, - ctx: &'ctx Context, - llvm_ty: Self::Base, - ) -> Result<(), String> { - Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let fields = SliceFields::new(ctx, llvm_usize); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields( + fields, + llvm_ty, + "Slice", + &[ + (fields.start.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.start`, got {ty}")) + } + }), + (fields.stop.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.stop`, got {ty}")) + } + }), + (fields.step.name(), &|ty| { + if ty.is_int_type() { + Ok(()) + } else { + Err(format!("Expected int type for `Slice.step`, got {ty}")) + } + }), + ], + ) } fn alloca_type(&self) -> impl BasicType<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 4ba5b6a..075f7f6 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -21,15 +21,6 @@ pub struct ListValue<'ctx> { } impl<'ctx> ListValue<'ctx> { - /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ListType::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -37,7 +28,7 @@ impl<'ctx> ListValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); ListValue { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index c789fe0..dae10f3 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,7 +1,6 @@ -use inkwell::{context::Context, values::BasicValue}; +use inkwell::{types::IntType, values::BasicValue}; use super::types::ProxyType; -use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use range::*; @@ -24,21 +23,8 @@ pub trait ProxyValue<'ctx>: Into { type Type: ProxyType<'ctx, Value = Self>; /// Checks whether `value` can be represented by this [`ProxyValue`]. - fn is_instance( - generator: &G, - ctx: &'ctx Context, - value: impl BasicValue<'ctx>, - ) -> Result<(), String> { - Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type()) - } - - /// Checks whether `value` can be represented by this [`ProxyValue`]. - fn is_representable( - generator: &G, - ctx: &'ctx Context, - value: Self::Base, - ) -> Result<(), String> { - Self::is_instance(generator, ctx, value.as_basic_value_enum()) + fn is_instance(value: impl BasicValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + Self::Type::is_representable(value.as_basic_value_enum().get_type(), llvm_usize) } /// Returns the [type][ProxyType] of this value. diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index b5182a2..acbd299 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -26,15 +26,6 @@ pub struct ShapeEntryValue<'ctx> { } impl<'ctx> ShapeEntryValue<'ctx> { - /// Checks whether `value` is an instance of `ShapeEntry`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -42,7 +33,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 0fbb85f..65e8025 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -23,15 +23,6 @@ pub struct ContiguousNDArrayValue<'ctx> { } impl<'ctx> ContiguousNDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -40,7 +31,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, item: dtype, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 60c9c3b..3b7b8f1 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -30,15 +30,6 @@ pub struct NDIndexValue<'ctx> { } impl<'ctx> NDIndexValue<'ctx> { - /// Checks whether `value` is an instance of `ndindex`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDIndexValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -46,7 +37,7 @@ impl<'ctx> NDIndexValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 1bf5db3..cba35ad 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -49,15 +49,6 @@ pub struct NDArrayValue<'ctx> { } impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - NDArrayType::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -67,7 +58,7 @@ impl<'ctx> NDArrayValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 86f370e..5479b92 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -23,15 +23,6 @@ pub struct NDIterValue<'ctx> { } impl<'ctx> NDIterValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - ::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -41,7 +32,7 @@ impl<'ctx> NDIterValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, parent, indices, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 7e9976a..b1a5806 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -1,4 +1,7 @@ -use inkwell::values::{BasicValueEnum, IntValue, PointerValue}; +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue}, +}; use super::ProxyValue; use crate::codegen::{types::RangeType, CodeGenContext}; @@ -7,21 +10,21 @@ use crate::codegen::{types::RangeType, CodeGenContext}; #[derive(Copy, Clone)] pub struct RangeValue<'ctx> { value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } impl<'ctx> RangeValue<'ctx> { - /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. - pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> { - RangeType::is_representable(value.get_type()) - } - /// Creates an [`RangeValue`] from a [`PointerValue`]. #[must_use] - pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { - debug_assert!(Self::is_representable(ptr).is_ok()); + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - RangeValue { value: ptr, name } + RangeValue { value: ptr, llvm_usize, name } } fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -138,7 +141,7 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type()) + RangeType::from_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 5167e47..4558f18 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -14,15 +14,6 @@ pub struct TupleValue<'ctx> { } impl<'ctx> TupleValue<'ctx> { - /// Checks whether `value` is an instance of `tuple`, returning [Err] if `value` is not an - /// instance. - pub fn is_representable( - value: StructValue<'ctx>, - _llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - TupleType::is_representable(value.get_type()) - } - /// Creates an [`TupleValue`] from a [`StructValue`]. #[must_use] pub fn from_struct_value( @@ -30,7 +21,7 @@ impl<'ctx> TupleValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(value, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(value, llvm_usize).is_ok()); Self { value, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index dffe6ce..df9e4de 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -24,15 +24,6 @@ pub struct SliceValue<'ctx> { } impl<'ctx> SliceValue<'ctx> { - /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is - /// not an instance. - pub fn is_representable( - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> Result<(), String> { - >::Type::is_representable(value.get_type(), llvm_usize) - } - /// Creates an [`SliceValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -41,7 +32,7 @@ impl<'ctx> SliceValue<'ctx> { llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); Self { value: ptr, int_ty, llvm_usize, name } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e06366c..1c3b085 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -17,10 +17,10 @@ use crate::{ builtin_fns, numpy::*, stmt::{exn_constructor, gen_if_callback}, - types::ndarray::NDArrayType, + types::{ndarray::NDArrayType, RangeType}, values::{ ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray}, - ProxyValue, RangeValue, + ProxyValue, }, }, symbol_resolver::SymbolValue, @@ -577,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeValue::from_pointer_value(zelf, Some("range")); + let zelf = RangeType::new(ctx).map_value(zelf, Some("range")); let mut start = None; let mut stop = None; From 96e98947cccdcdfdd6d69865fdcadb85d116fc3f Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Jan 2025 14:46:30 +0800 Subject: [PATCH 29/49] [core] codegen: Add StructProxy{Type,Value} --- nac3core/src/codegen/types/structure.rs | 44 +++++++++++++++++++++++- nac3core/src/codegen/values/mod.rs | 1 + nac3core/src/codegen/values/structure.rs | 24 +++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 nac3core/src/codegen/values/structure.rs diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 87781d1..0e35c81 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -2,13 +2,55 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, - types::{BasicTypeEnum, IntType, StructType}, + types::{BasicTypeEnum, IntType, PointerType, StructType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + AddressSpace, }; use itertools::Itertools; +use super::ProxyType; use crate::codegen::CodeGenContext; +/// A LLVM type that is used to represent a corresponding structure-like type in NAC3. +pub trait StructProxyType<'ctx>: ProxyType<'ctx, Base = PointerType<'ctx>> { + /// The concrete type of [`StructFields`]. + type StructFields: StructFields<'ctx>; + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][StructType]. + fn has_same_struct_repr( + llvm_ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_pointer_repr(llvm_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Whether this [`StructProxyType`] has the same LLVM type representation as + /// [`llvm_ty`][PointerType]. + fn has_same_pointer_repr( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + Self::has_same_repr(llvm_ty, llvm_usize) + } + + /// Returns the fields present in this [`StructProxyType`]. + #[must_use] + fn get_fields(&self) -> Self::StructFields; + + /// Returns the [`StructType`]. + #[must_use] + fn get_struct_type(&self) -> StructType<'ctx> { + self.as_base_type().get_element_type().into_struct_type() + } + + /// Returns the [`PointerType`] representing this type. + #[must_use] + fn get_pointer_type(&self) -> PointerType<'ctx> { + self.as_base_type() + } +} + /// Trait indicating that the structure is a field-wise representation of an LLVM structure. /// /// # Usage diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index dae10f3..9a24635 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -10,6 +10,7 @@ mod array; mod list; pub mod ndarray; mod range; +pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/values/structure.rs b/nac3core/src/codegen/values/structure.rs new file mode 100644 index 0000000..dfe4543 --- /dev/null +++ b/nac3core/src/codegen/values/structure.rs @@ -0,0 +1,24 @@ +use inkwell::values::{BasicValueEnum, PointerValue, StructValue}; + +use super::ProxyValue; +use crate::codegen::{types::structure::StructProxyType, CodeGenContext}; + +/// An LLVM value that is used to represent a corresponding structure-like value in NAC3. +pub trait StructProxyValue<'ctx>: + ProxyValue<'ctx, Base = PointerValue<'ctx>, Type: StructProxyType<'ctx, Value = Self>> +{ + /// Returns this value as a [`StructValue`]. + #[must_use] + fn get_struct_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructValue<'ctx> { + ctx.builder + .build_load(self.get_pointer_value(ctx), "") + .map(BasicValueEnum::into_struct_value) + .unwrap() + } + + /// Returns this value as a [`PointerValue`]. + #[must_use] + fn get_pointer_value(&self, _: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.as_base_value() + } +} From b521bc0c821643e9d9bc501e0b89544843a6d076 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 10:10:23 +0800 Subject: [PATCH 30/49] [core] codegen: Add Proxy{Type,Value}::as_abi_{type,value} Needed for PtrToOrBasic{Type,Value}. --- nac3artiq/src/codegen.rs | 2 +- nac3artiq/src/symbol_resolver.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 130 +++++++++--------- nac3core/src/codegen/expr.rs | 18 +-- nac3core/src/codegen/irrt/ndarray/array.rs | 4 +- nac3core/src/codegen/irrt/ndarray/basic.rs | 33 +++-- .../src/codegen/irrt/ndarray/broadcast.rs | 2 +- nac3core/src/codegen/irrt/ndarray/indexing.rs | 4 +- nac3core/src/codegen/irrt/ndarray/iter.rs | 8 +- .../src/codegen/irrt/ndarray/transpose.rs | 4 +- nac3core/src/codegen/mod.rs | 8 +- nac3core/src/codegen/numpy.rs | 16 +-- nac3core/src/codegen/test.rs | 6 +- nac3core/src/codegen/types/list.rs | 7 +- nac3core/src/codegen/types/mod.rs | 12 +- nac3core/src/codegen/types/ndarray/array.rs | 10 +- .../src/codegen/types/ndarray/broadcast.rs | 7 +- .../src/codegen/types/ndarray/contiguous.rs | 7 +- .../src/codegen/types/ndarray/indexing.rs | 7 +- nac3core/src/codegen/types/ndarray/mod.rs | 7 +- nac3core/src/codegen/types/ndarray/nditer.rs | 7 +- nac3core/src/codegen/types/range.rs | 9 +- nac3core/src/codegen/types/tuple.rs | 5 + nac3core/src/codegen/types/utils/slice.rs | 7 +- nac3core/src/codegen/values/list.rs | 7 +- nac3core/src/codegen/values/mod.rs | 14 +- .../src/codegen/values/ndarray/broadcast.rs | 5 + .../src/codegen/values/ndarray/contiguous.rs | 13 +- .../src/codegen/values/ndarray/indexing.rs | 5 + nac3core/src/codegen/values/ndarray/mod.rs | 19 ++- nac3core/src/codegen/values/ndarray/nditer.rs | 9 +- nac3core/src/codegen/values/range.rs | 11 +- nac3core/src/codegen/values/tuple.rs | 5 + nac3core/src/codegen/values/utils/slice.rs | 5 + nac3core/src/toplevel/builtins.rs | 8 +- 35 files changed, 268 insertions(+), 159 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 4c86028..2cc5438 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -761,7 +761,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - ndarray.as_base_value().into() + ndarray.as_abi_value(ctx).into() } _ => { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 4b398a9..06a9400 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1146,7 +1146,7 @@ impl InnerResolver { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1316,7 +1316,7 @@ impl InnerResolver { }; let ndarray = llvm_ndarray - .as_base_type() + .as_abi_type() .get_element_type() .into_struct_type() .const_named_struct(&[ @@ -1328,7 +1328,7 @@ impl InnerResolver { ]); let ndarray_global = ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 6cacac4..20a89d0 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,6 +1,6 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, IntValue}, + values::{BasicValueEnum, IntValue}, FloatPredicate, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -137,7 +137,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -197,7 +197,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -273,7 +273,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -338,7 +338,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -402,7 +402,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -448,7 +448,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -485,7 +485,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -550,7 +550,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -600,7 +600,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -650,7 +650,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -767,7 +767,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1026,7 +1026,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - result.as_base_value().into() + result.as_abi_value(ctx).into() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1653,11 +1653,11 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_cholesky( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_qr` linalg function @@ -1699,20 +1699,20 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_qr( ctx, - x1_c.as_base_value().into(), - q_c.as_base_value().into(), - r_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), + r_c.as_abi_value(ctx).into(), None, ); - let q = q.as_base_value().as_basic_value_enum(); - let r = r.as_base_value().as_basic_value_enum(); + let q = q.as_abi_value(ctx); + let r = r.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( ctx, - [q, r], + [q.into(), r.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_svd` linalg function @@ -1760,19 +1760,19 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_svd( ctx, - x1_c.as_base_value().into(), - u_c.as_base_value().into(), - s_c.as_base_value().into(), - vh_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), + s_c.as_abi_value(ctx).into(), + vh_c.as_abi_value(ctx).into(), None, ); - let u = u.as_base_value().as_basic_value_enum(); - let s = s.as_base_value().as_basic_value_enum(); - let vh = vh.as_base_value().as_basic_value_enum(); + let u = u.as_abi_value(ctx); + let s = s.as_abi_value(ctx); + let vh = vh.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) - .construct_from_objects(ctx, [u, s, vh], None); - Ok(tuple.as_base_value().into()) + .construct_from_objects(ctx, [u.into(), s.into(), vh.into()], None); + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_inv` linalg function @@ -1800,12 +1800,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_inv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_pinv` linalg function @@ -1845,12 +1845,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let out_c = out.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_pinv( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_lu` linalg function @@ -1892,20 +1892,20 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let u_c = u.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_lu( ctx, - x1_c.as_base_value().into(), - l_c.as_base_value().into(), - u_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + l_c.as_abi_value(ctx).into(), + u_c.as_abi_value(ctx).into(), None, ); - let l = l.as_base_value().as_basic_value_enum(); - let u = u.as_base_value().as_basic_value_enum(); + let l = l.as_abi_value(ctx); + let u = u.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( ctx, - [l, u], + [l.into(), u.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -1953,13 +1953,13 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_np_linalg_matrix_power( ctx, - x1_c.as_base_value().into(), - x2_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + x2_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); - Ok(out.as_base_value().into()) + Ok(out.as_abi_value(ctx).into()) } /// Invokes the `np_linalg_det` linalg function @@ -1993,8 +1993,8 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( let out_c = det.make_contiguous_ndarray(generator, ctx); extern_fns::call_np_linalg_det( ctx, - x1_c.as_base_value().into(), - out_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + out_c.as_abi_value(ctx).into(), None, ); @@ -2035,20 +2035,20 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let z_c = z.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_schur( ctx, - x1_c.as_base_value().into(), - t_c.as_base_value().into(), - z_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + t_c.as_abi_value(ctx).into(), + z_c.as_abi_value(ctx).into(), None, ); - let t = t.as_base_value().as_basic_value_enum(); - let z = z.as_base_value().as_basic_value_enum(); + let t = t.as_abi_value(ctx); + let z = z.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( ctx, - [t, z], + [t.into(), z.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2083,18 +2083,18 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let q_c = q.make_contiguous_ndarray(generator, ctx); extern_fns::call_sp_linalg_hessenberg( ctx, - x1_c.as_base_value().into(), - h_c.as_base_value().into(), - q_c.as_base_value().into(), + x1_c.as_abi_value(ctx).into(), + h_c.as_abi_value(ctx).into(), + q_c.as_abi_value(ctx).into(), None, ); - let h = h.as_base_value().as_basic_value_enum(); - let q = q.as_base_value().as_basic_value_enum(); + let h = h.as_abi_value(ctx); + let q = q.as_abi_value(ctx); let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( ctx, - [h, q], + [h.into(), q.into()], None, ); - Ok(tuple.as_base_value().into()) + Ok(tuple.as_abi_value(ctx).into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4da0ef3..7a1d42f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1307,7 +1307,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( emit_cont_bb(ctx, list); - Ok(Some(list.as_base_value().into())) + Ok(Some(list.as_abi_value(ctx).into())) } /// Generates LLVM IR for a binary operator expression using the [`Type`] and @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_base_value().into())) + Ok(Some(new_list.as_abi_value(ctx).into())) } _ => todo!("Operator not supported"), @@ -1601,7 +1601,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(result) }) .unwrap(); - Ok(Some(result.as_base_value().into())) + Ok(Some(result.as_abi_value(ctx).into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1796,7 +1796,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - mapped_ndarray.as_base_value().into() + mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() })) @@ -1883,7 +1883,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }, )?; - return Ok(Some(result_ndarray.as_base_value().into())); + return Ok(Some(result_ndarray.as_abi_value(ctx).into())); } } @@ -2493,7 +2493,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ); ctx.builder.build_store(elem_ptr, *v).unwrap(); } - arr_str_ptr.as_base_value().into() + arr_str_ptr.as_abi_value(ctx).into() } ExprKind::Tuple { elts, .. } => { let elements_val = elts @@ -2988,7 +2988,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v, (start, end, step), ); - res_array_ret.as_base_value().into() + res_array_ret.as_abi_value(ctx).into() } else { let len = v.load_size(ctx, Some("len")); let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { @@ -3050,7 +3050,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .index(generator, ctx, &indices) .split_unsized(generator, ctx) .to_basic_value_enum(); - return Ok(Some(ValueEnum::Dynamic(result))); + return Ok(Some(result.into())); } TypeEnum::TTuple { .. } => { let index: u32 = diff --git a/nac3core/src/codegen/irrt/ndarray/array.rs b/nac3core/src/codegen/irrt/ndarray/array.rs index 5e9c0f0..63a2ab0 100644 --- a/nac3core/src/codegen/irrt/ndarray/array.rs +++ b/nac3core/src/codegen/irrt/ndarray/array.rs @@ -36,7 +36,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato ctx, &name, None, - &[list.as_base_value().into(), ndims.into(), shape.base_ptr(ctx, generator).into()], + &[list.as_abi_value(ctx).into(), ndims.into(), shape.base_ptr(ctx, generator).into()], None, None, ); @@ -65,7 +65,7 @@ pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( ctx, &name, None, - &[list.as_base_value().into(), ndarray.as_base_value().into()], + &[list.as_abi_value(ctx).into(), ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index aa792b1..5f291c8 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -93,7 +93,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); @@ -101,7 +101,7 @@ pub fn call_nac3_ndarray_size<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("size"), None, ) @@ -118,7 +118,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); @@ -126,7 +126,7 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("nbytes"), None, ) @@ -143,7 +143,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); @@ -151,7 +151,7 @@ pub fn call_nac3_ndarray_len<'ctx>( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("len"), None, ) @@ -167,7 +167,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); @@ -175,7 +175,7 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], Some("is_c_contiguous"), None, ) @@ -194,7 +194,7 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!(index.get_type(), llvm_usize); @@ -204,7 +204,10 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( ctx, &name, Some(llvm_pi8.into()), - &[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())], + &[ + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), + (llvm_usize.into(), index.into()), + ], Some("pelement"), None, ) @@ -227,7 +230,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); assert_eq!( BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), @@ -241,7 +244,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized &name, Some(llvm_pi8.into()), &[ - (llvm_ndarray.into(), ndarray.as_base_value().into()), + (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], Some("pelement"), @@ -258,7 +261,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { - let llvm_ndarray = ndarray.get_type().as_base_type(); + let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); @@ -266,7 +269,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx, &name, None, - &[(llvm_ndarray.into(), ndarray.as_base_value().into())], + &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], None, None, ); @@ -288,7 +291,7 @@ pub fn call_nac3_ndarray_copy_data<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/broadcast.rs b/nac3core/src/codegen/irrt/ndarray/broadcast.rs index a7d40a5..59b0e4c 100644 --- a/nac3core/src/codegen/irrt/ndarray/broadcast.rs +++ b/nac3core/src/codegen/irrt/ndarray/broadcast.rs @@ -30,7 +30,7 @@ pub fn call_nac3_ndarray_broadcast_to<'ctx>( ctx, &name, None, - &[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()], + &[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()], None, None, ); diff --git a/nac3core/src/codegen/irrt/ndarray/indexing.rs b/nac3core/src/codegen/irrt/ndarray/indexing.rs index df5b27d..0d5d920 100644 --- a/nac3core/src/codegen/irrt/ndarray/indexing.rs +++ b/nac3core/src/codegen/irrt/ndarray/indexing.rs @@ -25,8 +25,8 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( &[ indices.size(ctx, generator).into(), indices.base_ptr(ctx, generator).into(), - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index ad90178..e4424df 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -40,8 +40,8 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - (iter.get_type().as_base_type().into(), iter.as_base_value().into()), - (ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()), + (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), + (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), ], None, @@ -63,7 +63,7 @@ pub fn call_nac3_nditer_has_element<'ctx>( ctx, &name, Some(ctx.ctx.bool_type().into()), - &[iter.as_base_value().into()], + &[iter.as_abi_value(ctx).into()], None, None, ) @@ -77,5 +77,5 @@ pub fn call_nac3_nditer_has_element<'ctx>( pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); - infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); + infer_and_call_function(ctx, &name, None, &[iter.as_abi_value(ctx).into()], None, None); } diff --git a/nac3core/src/codegen/irrt/ndarray/transpose.rs b/nac3core/src/codegen/irrt/ndarray/transpose.rs index 6d152dd..331611f 100644 --- a/nac3core/src/codegen/irrt/ndarray/transpose.rs +++ b/nac3core/src/codegen/irrt/ndarray/transpose.rs @@ -34,8 +34,8 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( &name, None, &[ - src_ndarray.as_base_value().into(), - dst_ndarray.as_base_value().into(), + src_ndarray.as_abi_value(ctx).into(), + dst_ndarray.as_abi_value(ctx).into(), axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(), axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| { axes.base_ptr(ctx, generator) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 73a28b7..a188d1c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -562,7 +562,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_abi_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -572,7 +572,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_abi_type().into() } _ => unreachable!( @@ -626,7 +626,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_abi_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), @@ -800,7 +800,7 @@ pub fn gen_func_impl< Some(t) => t.as_basic_type_enum(), } }), - (primitives.range, RangeType::new_with_generator(generator, context).as_base_type().into()), + (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { let name = "Exception"; if let Some(t) = module.get_struct_type(name) { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3cdd1ef..2eec88d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -44,7 +44,7 @@ pub fn gen_ndarray_empty<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.zeros`. @@ -69,7 +69,7 @@ pub fn gen_ndarray_zeros<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.ones`. @@ -94,7 +94,7 @@ pub fn gen_ndarray_ones<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.full`. @@ -127,7 +127,7 @@ pub fn gen_ndarray_full<'ctx>( fill_value_arg, None, ); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } pub fn gen_ndarray_array<'ctx>( @@ -166,7 +166,7 @@ pub fn gen_ndarray_array<'ctx>( .construct_numpy_array(generator, context, (obj_ty, obj_arg), copy, None) .atleast_nd(generator, context, ndims); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.eye`. @@ -225,7 +225,7 @@ pub fn gen_ndarray_eye<'ctx>( let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.identity`. @@ -253,7 +253,7 @@ pub fn gen_ndarray_identity<'ctx>( .unwrap(); let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.copy`. @@ -274,7 +274,7 @@ pub fn gen_ndarray_copy<'ctx>( let this = NDArrayType::from_unifier_type(generator, context, this_ty) .map_value(this_arg.into_pointer_value(), None); let ndarray = this.make_copy(generator, context); - Ok(ndarray.as_base_value()) + Ok(ndarray.as_abi_value(context)) } /// Generates LLVM IR for `ndarray.fill`. diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ecc0ba9..15c4654 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -447,7 +447,7 @@ fn test_classes_list_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); - assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); + assert!(ListType::is_representable(llvm_list.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -458,7 +458,7 @@ fn test_classes_range_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_range = RangeType::new_with_generator(&generator, &ctx); - assert!(RangeType::is_representable(llvm_range.as_base_type(), llvm_usize).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_abi_type(), llvm_usize).is_ok()); } #[test] @@ -470,5 +470,5 @@ fn test_classes_ndarray_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); - assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); + assert!(NDArrayType::is_representable(llvm_ndarray.as_abi_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 60015b8..f99ad5c 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -305,6 +305,7 @@ impl<'ctx> ListType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; @@ -344,12 +345,16 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 5865d63..abeab5b 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -38,8 +38,10 @@ pub mod utils; /// A LLVM type that is used to represent a corresponding type in NAC3. pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a - /// [LLVM pointer type][PointerType] for any non-primitive types. + /// The ABI type of which values of this type possess. + type ABI: BasicType<'ctx>; + + /// The LLVM type of which values of this type possess. type Base: BasicType<'ctx>; /// The type of values represented by this type. @@ -118,4 +120,10 @@ pub trait ProxyType<'ctx>: Into { /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; + + /// Returns this proxy as its ABI type, i.e. the expected type representation if a value of this + /// [`ProxyType`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_type(&self) -> Self::ABI; } diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 7061112..9630ec1 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -151,7 +151,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, |generator, ctx| { let ndarray = self.construct_numpy_array_from_list_copy_none_impl( @@ -160,7 +160,7 @@ impl<'ctx> NDArrayType<'ctx> { (list_ty, list), name, ); - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() @@ -189,11 +189,11 @@ impl<'ctx> NDArrayType<'ctx> { |_generator, _ctx| Ok(copy), |generator, ctx| { let ndarray = ndarray.make_copy(generator, ctx); // Force copy - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, - |_generator, _ctx| { + |_generator, ctx| { // No need to copy. Return `ndarray` itself. - Ok(Some(ndarray.as_base_value())) + Ok(Some(ndarray.as_abi_value(ctx))) }, ) .unwrap() diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index af1a26f..40847ce 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -127,6 +127,7 @@ impl<'ctx> ShapeEntryType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ShapeEntryValue<'ctx>; @@ -160,12 +161,16 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 1987ab6..40311a5 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -189,6 +189,7 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = ContiguousNDArrayValue<'ctx>; @@ -230,12 +231,16 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 8e15c90..ec214ce 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -158,6 +158,7 @@ impl<'ctx> NDIndexType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIndexValue<'ctx>; @@ -188,12 +189,16 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index 1743fe2..a79a1f3 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -427,6 +427,7 @@ impl<'ctx> NDArrayType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; @@ -458,12 +459,16 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index 6246eef..ba21a7e 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -185,6 +185,7 @@ impl<'ctx> NDIterType<'ctx> { } impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = NDIterValue<'ctx>; @@ -216,12 +217,16 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index 158152b..b6f15c7 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -72,7 +72,7 @@ impl<'ctx> RangeType<'ctx> { /// Returns the type of all fields of this `range` type. #[must_use] pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + self.as_abi_type().get_element_type().into_array_type().get_element_type().into_int_type() } /// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type. @@ -120,6 +120,7 @@ impl<'ctx> RangeType<'ctx> { } impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = RangeValue<'ctx>; @@ -163,12 +164,16 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index d05b7f2..29e9323 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -157,6 +157,7 @@ impl<'ctx> TupleType<'ctx> { } impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { + type ABI = StructType<'ctx>; type Base = StructType<'ctx>; type Value = TupleValue<'ctx>; @@ -182,6 +183,10 @@ impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for StructType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index b7fafef..e482ed5 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -174,6 +174,7 @@ impl<'ctx> SliceType<'ctx> { } impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { + type ABI = PointerType<'ctx>; type Base = PointerType<'ctx>; type Value = SliceValue<'ctx>; @@ -229,12 +230,16 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { } fn alloca_type(&self) -> impl BasicType<'ctx> { - self.as_base_type().get_element_type().into_struct_type() + self.as_abi_type().get_element_type().into_struct_type() } fn as_base_type(&self) -> Self::Base { self.ty } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } } impl<'ctx> From> for PointerType<'ctx> { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 075f7f6..8b2b6cb 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -110,7 +110,7 @@ impl<'ctx> ListValue<'ctx> { let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( - ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), + ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_abi_type(), "").unwrap(), self.llvm_usize, self.name, ) @@ -118,6 +118,7 @@ impl<'ctx> ListValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ListType<'ctx>; @@ -128,6 +129,10 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 9a24635..90f327e 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,6 +1,6 @@ use inkwell::{types::IntType, values::BasicValue}; -use super::types::ProxyType; +use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; @@ -16,8 +16,10 @@ pub mod utils; /// A LLVM type that is used to represent a non-primitive value in NAC3. pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the - /// [LLVM pointer type][PointerValue]. + /// The ABI type of LLVM values represented by this instance. + type ABI: BasicValue<'ctx>; + + /// The type of LLVM values represented by this instance. type Base: BasicValue<'ctx>; /// The type of this value. @@ -33,4 +35,10 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; + + /// Returns this proxy as its ABI value, i.e. the expected value representation if a value + /// represented by this [`ProxyValue`] is being passed into or returned from a function. + /// + /// See [`CodeGenContext::get_llvm_abi_type`]. + fn as_abi_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> Self::ABI; } diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index acbd299..883b461 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -58,6 +58,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ShapeEntryType<'ctx>; @@ -68,6 +69,10 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 65e8025..a23be22 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_base_value(), value, self.name); + self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -49,7 +49,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_base_value(), value, self.name); + self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -61,7 +61,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.as_base_value(), value, self.name); + self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { @@ -70,6 +70,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = ContiguousNDArrayType<'ctx>; @@ -84,6 +85,10 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -124,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); + let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 3b7b8f1..0084671 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -68,6 +68,7 @@ impl<'ctx> NDIndexValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIndexType<'ctx>; @@ -78,6 +79,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index cba35ad..e45fe85 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name); + self.shape_field(ctx).set(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name); + self.strides_field(ctx).set(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name); + self.data_field(ctx).set(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -462,6 +462,7 @@ impl<'ctx> NDArrayValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDArrayType<'ctx>; @@ -477,6 +478,10 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { @@ -503,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.shape_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -601,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.strides_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -699,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name) + self.0.data_field(ctx).get(ctx, self.0.value, self.0.name) } fn size( @@ -966,7 +971,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { match self { ScalarOrNDArray::Scalar(scalar) => scalar, - ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(), + ScalarOrNDArray::NDArray(ndarray) => ndarray.value.into(), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 5479b92..3fdd0a8 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).get(ctx, self.as_base_value(), self.name); + let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).get(ctx, self.as_base_value(), self.name) + self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. @@ -105,6 +105,7 @@ impl<'ctx> NDIterValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = NDIterType<'ctx>; @@ -115,6 +116,10 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index b1a5806..20bdba7 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -34,7 +34,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], var_name.as_str(), ) @@ -49,7 +49,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], var_name.as_str(), ) @@ -64,7 +64,7 @@ impl<'ctx> RangeValue<'ctx> { unsafe { ctx.builder .build_in_bounds_gep( - self.as_base_value(), + self.as_abi_value(ctx), &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], var_name.as_str(), ) @@ -137,6 +137,7 @@ impl<'ctx> RangeValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = RangeType<'ctx>; @@ -147,6 +148,10 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 4558f18..08b2b8b 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -57,6 +57,7 @@ impl<'ctx> TupleValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { + type ABI = StructValue<'ctx>; type Base = StructValue<'ctx>; type Type = TupleType<'ctx>; @@ -67,6 +68,10 @@ impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for StructValue<'ctx> { diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index df9e4de..21453f4 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -150,6 +150,7 @@ impl<'ctx> SliceValue<'ctx> { } impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { + type ABI = PointerValue<'ctx>; type Base = PointerValue<'ctx>; type Type = SliceType<'ctx>; @@ -160,6 +161,10 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } } impl<'ctx> From> for PointerValue<'ctx> { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1c3b085..165f64a 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -664,7 +664,7 @@ impl<'a> BuiltinBuilder<'a> { zelf.store_end(ctx, stop); zelf.store_step(ctx, step); - Ok(Some(zelf.as_base_value().into())) + Ok(Some(zelf.as_abi_value(ctx).into())) }, )))), loc: None, @@ -1320,7 +1320,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(result_tuple.as_base_value().into())) + Ok(Some(result_tuple.as_abi_value(ctx).into())) }), ) } @@ -1356,7 +1356,7 @@ impl<'a> BuiltinBuilder<'a> { .map_value(arg_val.into_pointer_value(), None); let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument - Ok(Some(ndarray.as_base_value().into())) + Ok(Some(ndarray.as_abi_value(ctx).into())) }), ), @@ -1410,7 +1410,7 @@ impl<'a> BuiltinBuilder<'a> { _ => unreachable!(), }; - Ok(Some(new_ndarray.as_base_value().as_basic_value_enum())) + Ok(Some(new_ndarray.as_abi_value(ctx).as_basic_value_enum())) }), ) } From eec62c3bbb5e63a8bb5c256bba4514fbd8542f58 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 24 Jan 2025 10:53:14 +0800 Subject: [PATCH 31/49] [core] codegen: Refactor StructField getters and setters --- nac3core/src/codegen/types/structure.rs | 39 ++++++++++++++----- nac3core/src/codegen/values/list.rs | 6 +-- .../src/codegen/values/ndarray/broadcast.rs | 4 +- .../src/codegen/values/ndarray/contiguous.rs | 12 +++--- .../src/codegen/values/ndarray/indexing.rs | 8 ++-- nac3core/src/codegen/values/ndarray/mod.rs | 16 ++++---- nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/utils/slice.rs | 30 +++++++------- 8 files changed, 70 insertions(+), 49 deletions(-) diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 0e35c81..d2622b0 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use inkwell::{ context::AsContextRef, types::{BasicTypeEnum, IntType, PointerType, StructType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + values::{AggregateValueEnum, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -203,17 +203,38 @@ where /// Gets the value of this field for a given `obj`. #[must_use] - pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value { - obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap() + pub fn extract_value(&self, ctx: &CodeGenContext<'ctx, '_>, obj: StructValue<'ctx>) -> Value { + Value::try_from( + ctx.builder + .build_extract_value( + obj, + self.index, + &format!("{}.{}", obj.get_name().to_str().unwrap(), self.name), + ) + .unwrap(), + ) + .unwrap() } /// Sets the value of this field for a given `obj`. - pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) { - obj.set_field_at_index(self.index, value); + #[must_use] + pub fn insert_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + obj: StructValue<'ctx>, + value: Value, + ) -> StructValue<'ctx> { + let obj_name = obj.get_name().to_str().unwrap(); + let new_obj_name = if obj_name.chars().all(char::is_numeric) { "" } else { obj_name }; + + ctx.builder + .build_insert_value(obj, value, self.index, new_obj_name) + .map(AggregateValueEnum::into_struct_value) + .unwrap() } - /// Gets the value of this field for a pointer-to-structure. - pub fn get( + /// Loads the value of this field for a pointer-to-structure. + pub fn load( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, @@ -229,8 +250,8 @@ where .unwrap() } - /// Sets the value of this field for a pointer-to-structure. - pub fn set( + /// Stores the value of this field for a pointer-to-structure. + pub fn store( &self, ctx: &CodeGenContext<'ctx, '_>, pobj: PointerValue<'ctx>, diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 8b2b6cb..cdd1a41 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -45,7 +45,7 @@ impl<'ctx> ListValue<'ctx> { /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - self.items_field(ctx).set(ctx, self.value, data, self.name); + self.items_field(ctx).store(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -91,7 +91,7 @@ impl<'ctx> ListValue<'ctx> { pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { debug_assert_eq!(size.get_type(), ctx.get_size_type()); - self.len_field(ctx).set(ctx, self.value, size, self.name); + self.len_field(ctx).store(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. @@ -100,7 +100,7 @@ impl<'ctx> ListValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> IntValue<'ctx> { - self.len_field(ctx).get(ctx, self.value, name) + self.len_field(ctx).load(ctx, self.value, name) } /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index 883b461..e30bfae 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -44,7 +44,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { /// Stores the number of dimensions into this value. pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.value, value, self.name); + self.ndims_field().store(ctx, self.value, value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -53,7 +53,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { /// Stores the shape into this value. pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.value, value, self.name); + self.shape_field().store(ctx, self.value, value, self.name); } } diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index a23be22..b8bf0af 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -41,7 +41,7 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.ndims_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.ndims_field().store(ctx, self.as_abi_value(ctx), value, self.name); } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -49,11 +49,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.shape_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.shape_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.shape_field().get(ctx, self.value, self.name) + self.shape_field().load(ctx, self.value, self.name) } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -61,11 +61,11 @@ impl<'ctx> ContiguousNDArrayValue<'ctx> { } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.as_abi_value(ctx), value, self.name); + self.data_field().store(ctx, self.as_abi_value(ctx), value, self.name); } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } } @@ -129,7 +129,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); + let data = self.data_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 0084671..49fdfe1 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -47,11 +47,11 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.type_field().get(ctx, self.value, self.name) + self.type_field().load(ctx, self.value, self.name) } pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { - self.type_field().set(ctx, self.value, value, self.name); + self.type_field().store(ctx, self.value, value, self.name); } fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { @@ -59,11 +59,11 @@ impl<'ctx> NDIndexValue<'ctx> { } pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field().get(ctx, self.value, self.name) + self.data_field().load(ctx, self.value, self.name) } pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - self.data_field().set(ctx, self.value, value, self.name); + self.data_field().store(ctx, self.value, value, self.name); } } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index e45fe85..38c87e0 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -94,12 +94,12 @@ impl<'ctx> NDArrayValue<'ctx> { pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); - self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); + self.itemsize_field(ctx).store(ctx, self.value, itemsize, self.name); } /// Returns the size of each element of this `NDArray` as a value. pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.itemsize_field(ctx).get(ctx, self.value, self.name) + self.itemsize_field(ctx).load(ctx, self.value, self.name) } fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { @@ -108,7 +108,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).set(ctx, self.value, dims, self.name); + self.shape_field(ctx).store(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -136,7 +136,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).set(ctx, self.value, strides, self.name); + self.strides_field(ctx).store(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -171,7 +171,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).set(ctx, self.value, data.into_pointer_value(), self.name); + self.data_field(ctx).store(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -508,7 +508,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.shape_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( @@ -606,7 +606,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.strides_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( @@ -704,7 +704,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).get(ctx, self.0.value, self.0.name) + self.0.data_field(ctx).load(ctx, self.0.value, self.0.name) } fn size( diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 3fdd0a8..e485574 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -68,7 +68,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).get(ctx, self.as_abi_value(ctx), self.name); + let p = self.element_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -88,7 +88,7 @@ impl<'ctx> NDIterValue<'ctx> { /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).get(ctx, self.as_abi_value(ctx), self.name) + self.nth_field(ctx).load(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index 21453f4..f773935 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -42,7 +42,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_defined_field().get(ctx, self.value, self.name) + self.start_defined_field().load(ctx, self.value, self.name) } fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -50,22 +50,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.start_field().get(ctx, self.value, self.name) + self.start_field().load(ctx, self.value, self.name) } pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(start) => { - self.start_defined_field().set( + self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.start_field().set(ctx, self.value, start, self.name); + self.start_field().store(ctx, self.value, start, self.name); } - None => self.start_defined_field().set( + None => self.start_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -79,7 +79,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_defined_field().get(ctx, self.value, self.name) + self.stop_defined_field().load(ctx, self.value, self.name) } fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -87,22 +87,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.stop_field().get(ctx, self.value, self.name) + self.stop_field().load(ctx, self.value, self.name) } pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(stop) => { - self.stop_defined_field().set( + self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.stop_field().set(ctx, self.value, stop, self.name); + self.stop_field().store(ctx, self.value, stop, self.name); } - None => self.stop_defined_field().set( + None => self.stop_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), @@ -116,7 +116,7 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_defined_field().get(ctx, self.value, self.name) + self.step_defined_field().load(ctx, self.value, self.name) } fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> { @@ -124,22 +124,22 @@ impl<'ctx> SliceValue<'ctx> { } pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.step_field().get(ctx, self.value, self.name) + self.step_field().load(ctx, self.value, self.name) } pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option>) { match value { Some(step) => { - self.step_defined_field().set( + self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_all_ones(), self.name, ); - self.step_field().set(ctx, self.value, step, self.name); + self.step_field().store(ctx, self.value, step, self.name); } - None => self.step_defined_field().set( + None => self.step_defined_field().store( ctx, self.value, ctx.ctx.bool_type().const_zero(), From 68da9b0ecff29a203bcd9539a204e05130e74420 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Jan 2025 23:23:05 +0800 Subject: [PATCH 32/49] [core] codegen: Implement StructProxy on existing proxies --- nac3artiq/src/codegen.rs | 8 +- nac3artiq/src/symbol_resolver.rs | 20 ++-- nac3core/src/codegen/builtin_fns.rs | 59 +++++++----- nac3core/src/codegen/expr.rs | 6 +- nac3core/src/codegen/numpy.rs | 10 +- nac3core/src/codegen/stmt.rs | 4 +- nac3core/src/codegen/types/list.rs | 54 ++++++++--- nac3core/src/codegen/types/ndarray/array.rs | 8 +- .../src/codegen/types/ndarray/broadcast.rs | 53 ++++++++--- .../src/codegen/types/ndarray/contiguous.rs | 57 +++++++++--- .../src/codegen/types/ndarray/indexing.rs | 49 +++++++--- nac3core/src/codegen/types/ndarray/mod.rs | 60 +++++++++--- nac3core/src/codegen/types/ndarray/nditer.rs | 59 +++++++++--- nac3core/src/codegen/types/range.rs | 35 ++++++- nac3core/src/codegen/types/tuple.rs | 37 ++++++-- nac3core/src/codegen/types/utils/slice.rs | 78 ++++++++++++---- nac3core/src/codegen/values/list.rs | 61 +++++++----- .../src/codegen/values/ndarray/broadcast.rs | 39 ++++++-- .../src/codegen/values/ndarray/contiguous.rs | 34 ++++++- .../src/codegen/values/ndarray/indexing.rs | 32 ++++++- nac3core/src/codegen/values/ndarray/mod.rs | 92 +++++++++++-------- nac3core/src/codegen/values/ndarray/nditer.rs | 52 ++++++++--- nac3core/src/codegen/values/ndarray/shape.rs | 4 +- nac3core/src/codegen/values/range.rs | 26 +++++- nac3core/src/codegen/values/tuple.rs | 22 ++++- nac3core/src/codegen/values/utils/slice.rs | 34 ++++++- nac3core/src/toplevel/builtins.rs | 10 +- 27 files changed, 729 insertions(+), 274 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 2cc5438..572accc 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -476,8 +476,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = - NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); + let ndarray = NDArrayType::new(ctx, dtype, ndims) + .map_pointer_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -1383,7 +1383,7 @@ fn polymorphic_print<'ctx>( let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) - .map_value(value.into_pointer_value(), None); + .map_pointer_value(value.into_pointer_value(), None); let num_0 = llvm_usize.const_zero(); @@ -1431,7 +1431,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeType::new(ctx).map_value(value.into_pointer_value(), None); + let val = RangeType::new(ctx).map_pointer_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 06a9400..f14a8ee 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -16,7 +16,7 @@ use pyo3::{ use super::PrimitivePythonId; use nac3core::{ codegen::{ - types::{ndarray::NDArrayType, ProxyType}, + types::{ndarray::NDArrayType, structure::StructProxyType, ProxyType}, values::ndarray::make_contiguous_strides, CodeGenContext, CodeGenerator, }, @@ -1315,17 +1315,13 @@ impl InnerResolver { .unwrap() }; - let ndarray = llvm_ndarray - .as_abi_type() - .get_element_type() - .into_struct_type() - .const_named_struct(&[ - ndarray_itemsize.into(), - ndarray_ndims.into(), - ndarray_shape.into(), - ndarray_strides.into(), - ndarray_data.into(), - ]); + let ndarray = llvm_ndarray.get_struct_type().const_named_struct(&[ + ndarray_itemsize.into(), + ndarray_ndims.into(), + ndarray_shape.into(), + ndarray_strides.into(), + ndarray_data.into(), + ]); let ndarray_global = ctx.module.add_global( llvm_ndarray.as_abi_type().get_element_type().into_struct_type(), diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 20a89d0..dfb9082 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -47,14 +47,14 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let range_ty = ctx.primitives.range; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeType::new(ctx).map_value(arg.into_pointer_value(), Some("range")); + let arg = RangeType::new(ctx).map_pointer_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { TypeEnum::TTuple { .. } => { let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_struct_value(), None); + .map_struct_value(arg.into_struct_value(), None); llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } @@ -62,7 +62,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_pointer_value(), None); + .map_pointer_value(arg.into_pointer_value(), None); ctx.builder .build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") .unwrap() @@ -72,7 +72,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { let list = ListType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg.into_pointer_value(), None); + .map_pointer_value(arg.into_pointer_value(), None); ctx.builder .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") .unwrap() @@ -126,7 +126,8 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -186,7 +187,8 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -262,7 +264,8 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -327,7 +330,8 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -391,7 +395,8 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -435,7 +440,8 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -474,7 +480,8 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -536,7 +543,8 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -587,7 +595,8 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -637,7 +646,8 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, n_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, n_ty).map_pointer_value(n, None); let result = ndarray .map( @@ -858,7 +868,8 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let ndarray = NDArrayType::from_unifier_type(generator, ctx, a_ty).map_value(n, None); + let ndarray = + NDArrayType::from_unifier_type(generator, ctx, a_ty).map_pointer_value(n, None); let llvm_dtype = ndarray.get_type().element_type(); let zero = llvm_usize.const_zero(); @@ -1638,7 +1649,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1672,7 +1683,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1727,7 +1738,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1785,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1820,7 +1831,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1865,7 +1876,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -1974,7 +1985,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); if !x1.get_type().element_type().is_float_type() { unsupported_type(ctx, FN_NAME, &[x1_ty]); @@ -2013,7 +2024,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { @@ -2061,7 +2072,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(x1, None); assert_eq!(x1.get_type().ndims(), 2); if !x1.get_type().element_type().is_float_type() { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 7a1d42f..20d296e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1151,7 +1151,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero @@ -1767,7 +1767,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) - .map_value(val.into_pointer_value(), None); + .map_pointer_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -3043,7 +3043,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray_ty = value.custom.unwrap(); let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?; let result = ndarray diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 2eec88d..dfb1b4d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -272,7 +272,7 @@ pub fn gen_ndarray_copy<'ctx>( obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let this = NDArrayType::from_unifier_type(generator, context, this_ty) - .map_value(this_arg.into_pointer_value(), None); + .map_pointer_value(this_arg.into_pointer_value(), None); let ndarray = this.make_copy(generator, context); Ok(ndarray.as_abi_value(context)) } @@ -295,7 +295,7 @@ pub fn gen_ndarray_fill<'ctx>( let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let this = NDArrayType::from_unifier_type(generator, context, this_ty) - .map_value(this_arg.into_pointer_value(), None); + .map_pointer_value(this_arg.into_pointer_value(), None); this.fill(generator, context, value_arg); Ok(()) } @@ -316,8 +316,10 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = + NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_pointer_value(n1, None); + let b = + NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_pointer_value(n2, None); // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. assert_eq!(a.get_type().ndims(), 1); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index e8f1d90..0c1b931 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -440,7 +440,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( // ``` let target = NDArrayType::from_unifier_type(generator, ctx, target_ty) - .map_value(target.into_pointer_value(), None); + .map_pointer_value(target.into_pointer_value(), None); let target = target.index(generator, ctx, &key); let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value)) @@ -511,7 +511,7 @@ pub fn gen_for( if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { let iter_val = - RangeType::new(ctx).map_value(iter_val.into_pointer_value(), Some("range")); + RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index f99ad5c..b4110da 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -1,7 +1,7 @@ use inkwell::{ - context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; @@ -13,8 +13,9 @@ use crate::{ codegen::{ types::structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, - values::{ListValue, ProxyValue}, + values::ListValue, CodeGenContext, CodeGenerator, }, typecheck::typedef::{iter_type_vars, Type, TypeEnum}, @@ -62,13 +63,6 @@ impl<'ctx> ListType<'ctx> { ListStructFields::new_typed(item, llvm_usize) } - /// See [`ListType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, _ctx: &impl AsContextRef<'ctx>) -> ListStructFields<'ctx> { - Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of a `List`. #[must_use] fn llvm_type( @@ -153,9 +147,15 @@ impl<'ctx> ListType<'ctx> { Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) } + /// Creates an [`ListType`] from a [`StructType`]. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); let ctx = ptr_ty.get_context(); @@ -295,9 +295,27 @@ impl<'ctx> ListType<'ctx> { /// Converts an existing value into a [`ListValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ListValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -357,6 +375,14 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ListType<'ctx> { + type StructFields = ListStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ListType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index 9630ec1..633a0b4 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -167,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) + NDArrayType::new(ctx, dtype, ndims).map_pointer_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. @@ -200,7 +200,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - ndarray.get_type().map_value(ndarray_val, name) + ndarray.get_type().map_pointer_value(ndarray_val, name) } /// Create a new ndarray like @@ -222,7 +222,7 @@ impl<'ctx> NDArrayType<'ctx> { if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { let list = ListType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name) } @@ -230,7 +230,7 @@ impl<'ctx> NDArrayType<'ctx> { if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name) } diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 40847ce..fa532b4 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -10,10 +10,10 @@ use nac3core_derive::StructFields; use crate::codegen::{ types::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }, - values::{ndarray::ShapeEntryValue, ProxyValue}, + values::ndarray::ShapeEntryValue, CodeGenContext, CodeGenerator, }; @@ -41,13 +41,6 @@ impl<'ctx> ShapeEntryType<'ctx> { ShapeEntryStructFields::new(ctx, llvm_usize) } - /// See [`ShapeEntryStructFields::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> ShapeEntryStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -78,9 +71,15 @@ impl<'ctx> ShapeEntryType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx)) } + /// Creates a [`ShapeEntryType`] from a [`StructType`] representing an `ShapeEntry`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -117,9 +116,27 @@ impl<'ctx> ShapeEntryType<'ctx> { /// Converts an existing value into a [`ShapeEntryValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ShapeEntryValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -173,6 +190,14 @@ impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ShapeEntryType<'ctx> { + type StructFields = ShapeEntryStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ShapeEntryType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index 40311a5..1857536 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -13,10 +13,11 @@ use crate::{ types::{ structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{ndarray::ContiguousNDArrayValue, ProxyValue}, + values::ndarray::ContiguousNDArrayValue, CodeGenContext, CodeGenerator, }, toplevel::numpy::unpack_ndarray_var_tys, @@ -67,13 +68,6 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ContiguousNDArrayStructFields::new_typed(item, llvm_usize) } - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> ContiguousNDArrayStructFields<'ctx> { - Self::fields(self.item, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type( @@ -123,9 +117,19 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) } + /// Creates an [`ContiguousNDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), item, llvm_usize) + } + /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, @@ -174,9 +178,28 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { /// Converts an existing value into a [`ContiguousNDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.item, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ContiguousNDArrayValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -243,6 +266,14 @@ impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type StructFields = ContiguousNDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.item, self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: ContiguousNDArrayType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index ec214ce..d00e0fb 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -10,12 +10,12 @@ use nac3core_derive::StructFields; use crate::codegen::{ types::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }, values::{ ndarray::{NDIndexValue, RustNDIndex}, - ArrayLikeIndexer, ArraySliceValue, ProxyValue, + ArrayLikeIndexer, ArraySliceValue, }, CodeGenContext, CodeGenerator, }; @@ -43,11 +43,6 @@ impl<'ctx> NDIndexType<'ctx> { NDIndexStructFields::new(ctx, llvm_usize) } - #[must_use] - pub fn get_fields(&self) -> NDIndexStructFields<'ctx> { - Self::fields(self.ty.get_context(), self.llvm_usize) - } - #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { let field_tys = @@ -76,7 +71,12 @@ impl<'ctx> NDIndexType<'ctx> { } #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -148,9 +148,26 @@ impl<'ctx> NDIndexType<'ctx> { } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) @@ -201,6 +218,14 @@ impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDIndexType<'ctx> { + type StructFields = NDIndexStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDIndexType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index a79a1f3..28ea527 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{BasicValue, IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValue, IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -9,12 +9,12 @@ use itertools::Itertools; use nac3core_derive::StructFields; use super::{ - structure::{check_struct_type_matches_fields, StructField, StructFields}, + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, ProxyType, }; use crate::{ codegen::{ - values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeMutator}, + values::{ndarray::NDArrayValue, TypedArrayLikeMutator}, {CodeGenContext, CodeGenerator}, }, toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, @@ -71,13 +71,6 @@ impl<'ctx> NDArrayType<'ctx> { NDArrayStructFields::new(ctx, llvm_usize) } - /// See [`NDArrayType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDArrayStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -183,9 +176,20 @@ impl<'ctx> NDArrayType<'ctx> { Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) } + /// Creates an [`NDArrayType`] from a [`StructType`] representing an `NDArray`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), dtype, ndims, llvm_usize) + } + /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, dtype: BasicTypeEnum<'ctx>, ndims: u64, @@ -411,9 +415,29 @@ impl<'ctx> NDArrayType<'ctx> { /// Converts an existing value into a [`NDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.dtype, + self.ndims, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`NDArrayValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -471,6 +495,14 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDArrayType<'ctx> { + type StructFields = NDArrayStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDArrayType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index ba21a7e..aec1a6f 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{IntValue, PointerValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -11,7 +11,9 @@ use nac3core_derive::StructFields; use super::ProxyType; use crate::codegen::{ irrt, - types::structure::{check_struct_type_matches_fields, StructField, StructFields}, + types::structure::{ + check_struct_type_matches_fields, StructField, StructFields, StructProxyType, + }, values::{ ndarray::{NDArrayValue, NDIterValue}, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter, @@ -50,13 +52,6 @@ impl<'ctx> NDIterType<'ctx> { NDIterStructFields::new(ctx, llvm_usize) } - /// See [`NDIterType::fields`]. - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> { - Self::fields(ctx, self.llvm_usize) - } - /// Creates an LLVM type corresponding to the expected structure of an `NDIter`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { @@ -87,9 +82,15 @@ impl<'ctx> NDIterType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx)) } + /// Creates an [`NDIterType`] from a [`StructType`] representing an `NDIter`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); Self { ty: ptr_ty, llvm_usize } @@ -159,7 +160,8 @@ impl<'ctx> NDIterType<'ctx> { let indices = TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into()); - let nditer = self.map_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); + let nditer = + self.map_pointer_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None); irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices); @@ -167,9 +169,30 @@ impl<'ctx> NDIterType<'ctx> { } #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + parent, + indices, + self.llvm_usize, + name, + ) + } + + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, parent: NDArrayValue<'ctx>, indices: ArraySliceValue<'ctx>, name: Option<&'ctx str>, @@ -229,6 +252,14 @@ impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for NDIterType<'ctx> { + type StructFields = NDIterStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: NDIterType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index b6f15c7..e8f6f4d 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,13 +1,14 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{ArrayValue, PointerValue}, AddressSpace, }; use super::ProxyType; use crate::{ codegen::{ - values::{ProxyValue, RangeValue}, + values::RangeValue, {CodeGenContext, CodeGenerator}, }, typecheck::typedef::{Type, TypeEnum}, @@ -61,9 +62,15 @@ impl<'ctx> RangeType<'ctx> { Self::new(ctx) } + /// Creates an [`RangeType`] from a [`ArrayType`]. + #[must_use] + pub fn from_array_type(arr_ty: ArrayType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(arr_ty.ptr_type(AddressSpace::default()), llvm_usize) + } + /// Creates an [`RangeType`] from a [`PointerType`]. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); RangeType { ty: ptr_ty, llvm_usize } @@ -110,9 +117,27 @@ impl<'ctx> RangeType<'ctx> { /// Converts an existing value into a [`RangeValue`]. #[must_use] - pub fn map_value( + pub fn map_array_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: ArrayValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_array_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`RangeValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value(value, self.llvm_usize, name) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 29e9323..ea66feb 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -1,16 +1,13 @@ use inkwell::{ context::Context, - types::{BasicType, BasicTypeEnum, IntType, StructType}, - values::BasicValueEnum, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValueEnum, PointerValue, StructValue}, }; use itertools::Itertools; use super::ProxyType; use crate::{ - codegen::{ - values::{ProxyValue, TupleValue}, - CodeGenContext, CodeGenerator, - }, + codegen::{values::TupleValue, CodeGenContext, CodeGenerator}, typecheck::typedef::{Type, TypeEnum}, }; @@ -77,12 +74,18 @@ impl<'ctx> TupleType<'ctx> { /// Creates an [`TupleType`] from a [`StructType`]. #[must_use] - pub fn from_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_struct_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok()); TupleType { ty: struct_ty, llvm_usize } } + /// Creates an [`TupleType`] from a [`PointerType`]. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize) + } + /// Returns the number of elements present in this [`TupleType`]. #[must_use] pub fn num_elements(&self) -> u32 { @@ -117,7 +120,10 @@ impl<'ctx> TupleType<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> >::Value { - self.map_value(Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), name) + self.map_struct_value( + Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), + name, + ) } /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of @@ -147,13 +153,24 @@ impl<'ctx> TupleType<'ctx> { /// Converts an existing value into a [`ListValue`]. #[must_use] - pub fn map_value( + pub fn map_struct_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + value: StructValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_struct_value(value, self.llvm_usize, name) } + + /// Converts an existing value into a [`TupleValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(ctx, value, self.llvm_usize, name) + } } impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> { diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index e482ed5..e43ac74 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -1,7 +1,7 @@ use inkwell::{ context::{AsContextRef, Context, ContextRef}, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,11 @@ use crate::codegen::{ types::{ structure::{ check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + StructProxyType, }, ProxyType, }, - values::{utils::SliceValue, ProxyValue}, + values::utils::SliceValue, CodeGenContext, CodeGenerator, }; @@ -27,7 +28,7 @@ pub struct SliceType<'ctx> { } #[derive(PartialEq, Eq, Clone, Copy, StructFields)] -pub struct SliceFields<'ctx> { +pub struct SliceStructFields<'ctx> { #[value_type(bool_type())] pub start_defined: StructField<'ctx, IntValue<'ctx>>, #[value_type(usize)] @@ -42,14 +43,14 @@ pub struct SliceFields<'ctx> { pub step: StructField<'ctx, IntValue<'ctx>>, } -impl<'ctx> SliceFields<'ctx> { - /// Creates a new instance of [`SliceFields`] with a custom integer type for its range values. +impl<'ctx> SliceStructFields<'ctx> { + /// Creates a new instance of [`SliceStructFields`] with a custom integer type for its range values. #[must_use] pub fn new_sized(ctx: &impl AsContextRef<'ctx>, int_ty: IntType<'ctx>) -> Self { let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) }; let mut counter = FieldIndexCounter::default(); - SliceFields { + SliceStructFields { start_defined: StructField::create(&mut counter, "start_defined", ctx.bool_type()), start: StructField::create(&mut counter, "start", int_ty), stop_defined: StructField::create(&mut counter, "stop_defined", ctx.bool_type()), @@ -61,16 +62,10 @@ impl<'ctx> SliceFields<'ctx> { } impl<'ctx> SliceType<'ctx> { - // TODO: Move this into e.g. StructProxyType - #[must_use] - pub fn get_fields(&self) -> SliceFields<'ctx> { - SliceFields::new_sized(&self.int_ty.get_context(), self.int_ty) - } - /// Creates an LLVM type corresponding to the expected structure of a `Slice`. #[must_use] fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> { - let field_tys = SliceFields::new_sized(&int_ty.get_context(), int_ty) + let field_tys = SliceStructFields::new_sized(&int_ty.get_context(), int_ty) .into_iter() .map(|field| field.1) .collect_vec(); @@ -90,6 +85,16 @@ impl<'ctx> SliceType<'ctx> { Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) } + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + int_ty: IntType<'ctx>, + ) -> Self { + Self::new_impl(ctx, int_ty, generator.get_size_type(ctx)) + } + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. #[must_use] pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { @@ -105,9 +110,19 @@ impl<'ctx> SliceType<'ctx> { Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) } + /// Creates an [`SliceType`] from a [`StructType`] representing a `slice`. + #[must_use] + pub fn from_struct_type( + ty: StructType<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), int_ty, llvm_usize) + } + /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. #[must_use] - pub fn from_type( + pub fn from_pointer_type( ptr_ty: PointerType<'ctx>, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>, @@ -157,11 +172,30 @@ impl<'ctx> SliceType<'ctx> { ) } + /// Converts an existing value into a [`SliceValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.int_ty, + self.llvm_usize, + name, + ) + } + /// Converts an existing value into a [`ContiguousNDArrayValue`]. #[must_use] - pub fn map_value( + pub fn map_pointer_value( &self, - value: <>::Value as ProxyValue<'ctx>>::Base, + value: PointerValue<'ctx>, name: Option<&'ctx str>, ) -> >::Value { >::Value::from_pointer_value( @@ -192,7 +226,7 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { let ctx = ty.get_context(); - let fields = SliceFields::new(ctx, llvm_usize); + let fields = SliceStructFields::new(ctx, llvm_usize); let llvm_ty = ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { @@ -242,6 +276,14 @@ impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> { } } +impl<'ctx> StructProxyType<'ctx> for SliceType<'ctx> { + type StructFields = SliceStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + SliceStructFields::new_sized(&self.ty.get_context(), self.int_ty) + } +} + impl<'ctx> From> for PointerType<'ctx> { fn from(value: SliceType<'ctx>) -> Self { value.as_base_type() diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index cdd1a41..453065b 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -1,14 +1,18 @@ use inkwell::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::{ - types::{structure::StructField, ListType, ProxyType}, + types::{ + structure::{StructField, StructProxyType}, + ListType, ProxyType, + }, {CodeGenContext, CodeGenerator}, }; @@ -21,6 +25,26 @@ pub struct ListValue<'ctx> { } impl<'ctx> ListValue<'ctx> { + /// Creates an [`ListValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -33,19 +57,13 @@ impl<'ctx> ListValue<'ctx> { ListValue { value: ptr, llvm_usize, name } } - fn items_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(&ctx.ctx).items - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.items_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn items_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().items } /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - self.items_field(ctx).store(ctx, self.value, data, self.name); + self.items_field().store(ctx, self.value, data, self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -83,15 +101,15 @@ impl<'ctx> ListValue<'ctx> { ListDataProxy(self) } - fn len_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(&ctx.ctx).len + fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().len } /// Stores the `size` of this `list` into this instance. pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { debug_assert_eq!(size.get_type(), ctx.get_size_type()); - self.len_field(ctx).store(ctx, self.value, size, self.name); + self.len_field().store(ctx, self.value, size, self.name); } /// Returns the size of this `list` as a value. @@ -100,7 +118,7 @@ impl<'ctx> ListValue<'ctx> { ctx: &CodeGenContext<'ctx, '_>, name: Option<&'ctx str>, ) -> IntValue<'ctx> { - self.len_field(ctx).load(ctx, self.value, name) + self.len_field().load(ctx, self.value, name) } /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. @@ -123,7 +141,7 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { type Type = ListType<'ctx>; fn get_type(&self) -> Self::Type { - ListType::from_type(self.as_base_value().get_type(), self.llvm_usize) + ListType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -135,6 +153,8 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ListValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ListValue<'ctx>) -> Self { value.as_base_value() @@ -159,12 +179,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() + self.0.items_field().load(ctx, self.0.value, self.0.name) } fn size( diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index e30bfae..4935a36 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, }; use itertools::Itertools; @@ -8,12 +8,13 @@ use crate::codegen::{ irrt, types::{ ndarray::{NDArrayType, ShapeEntryType}, - structure::StructField, + structure::{StructField, StructProxyType}, ProxyType, }, values::{ - ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + ndarray::NDArrayValue, structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, + ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + TypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, }; @@ -26,6 +27,26 @@ pub struct ShapeEntryValue<'ctx> { } impl<'ctx> ShapeEntryValue<'ctx> { + /// Creates an [`ShapeEntryValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`ShapeEntryValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -39,7 +60,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(self.value.get_type().get_context()).ndims + self.get_type().get_fields().ndims } /// Stores the number of dimensions into this value. @@ -48,7 +69,7 @@ impl<'ctx> ShapeEntryValue<'ctx> { } fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(self.value.get_type().get_context()).shape + self.get_type().get_fields().shape } /// Stores the shape into this value. @@ -63,7 +84,7 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { type Type = ShapeEntryType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -75,6 +96,8 @@ impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ShapeEntryValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ShapeEntryValue<'ctx>) -> Self { value.as_base_value() @@ -163,7 +186,7 @@ fn broadcast_shapes<'ctx, G, Shape>( None, ) }; - let shape_entry = llvm_shape_ty.map_value(pshape_entry, None); + let shape_entry = llvm_shape_ty.map_pointer_value(pshape_entry, None); let in_ndims = llvm_usize.const_int(*in_ndims, false); shape_entry.store_ndims(ctx, in_ndims); diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index b8bf0af..9dca06a 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -1,16 +1,17 @@ use inkwell::{ types::{BasicType, BasicTypeEnum, IntType}, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{ArrayLikeValue, NDArrayValue, ProxyValue}; +use super::NDArrayValue; use crate::codegen::{ stmt::gen_if_callback, types::{ ndarray::{ContiguousNDArrayType, NDArrayType}, - structure::StructField, + structure::{StructField, StructProxyType}, }, + values::{structure::StructProxyValue, ArrayLikeValue, ProxyValue}, CodeGenContext, CodeGenerator, }; @@ -23,6 +24,27 @@ pub struct ContiguousNDArrayValue<'ctx> { } impl<'ctx> ContiguousNDArrayValue<'ctx> { + /// Creates an [`ContiguousNDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, llvm_usize, name) + } + /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -75,7 +97,7 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { type Type = ContiguousNDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - >::Type::from_type( + >::Type::from_pointer_type( self.as_base_value().get_type(), self.item, self.llvm_usize, @@ -91,6 +113,8 @@ impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: ContiguousNDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -129,7 +153,7 @@ impl<'ctx> NDArrayValue<'ctx> { |_, ctx| Ok(self.is_c_contiguous(ctx)), |_, ctx| { // This ndarray is contiguous. - let data = self.data_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); + let data = self.data_field().load(ctx, self.as_abi_value(ctx), self.name); let data = ctx .builder .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 49fdfe1..6ed0ed0 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, AddressSpace, }; use itertools::Itertools; @@ -12,10 +12,12 @@ use crate::{ irrt, types::{ ndarray::{NDArrayType, NDIndexType}, - structure::StructField, + structure::{StructField, StructProxyType}, utils::SliceType, }, - values::{ndarray::NDArrayValue, utils::RustSlice, ProxyValue}, + values::{ + ndarray::NDArrayValue, structure::StructProxyValue, utils::RustSlice, ProxyValue, + }, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -30,6 +32,26 @@ pub struct NDIndexValue<'ctx> { } impl<'ctx> NDIndexValue<'ctx> { + /// Creates an [`NDIndexValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`NDIndexValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -73,7 +95,7 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { type Type = NDIndexType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -85,6 +107,8 @@ impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDIndexValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIndexValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 38c87e0..dcb6947 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -2,14 +2,14 @@ use std::iter::repeat_n; use inkwell::{ types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, }; use itertools::Itertools; use super::{ - ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor, - TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, + structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::{ @@ -18,7 +18,11 @@ use crate::{ llvm_intrinsics::{call_int_umin, call_memcpy_generic_array}, stmt::gen_for_callback_incrementing, type_aligned_alloca, - types::{ndarray::NDArrayType, structure::StructField, TupleType}, + types::{ + ndarray::NDArrayType, + structure::{StructField, StructProxyType}, + TupleType, + }, CodeGenContext, CodeGenerator, }, typecheck::typedef::{Type, TypeEnum}, @@ -49,6 +53,28 @@ pub struct NDArrayValue<'ctx> { } impl<'ctx> NDArrayValue<'ctx> { + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, dtype, ndims, llvm_usize, name) + } + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -63,52 +89,45 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name } } - fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).ndims - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name) + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().ndims } /// Stores the number of dimensions `ndims` into this instance. pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); + self.ndims_field().store(ctx, self.value, ndims, self.name); } /// Returns the number of dimensions of this `NDArray` as a value. pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + self.ndims_field().load(ctx, self.value, self.name) } - fn itemsize_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).itemsize + fn itemsize_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().itemsize } /// Stores the size of each element `itemsize` into this instance. pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); - self.itemsize_field(ctx).store(ctx, self.value, itemsize, self.name); + self.itemsize_field().store(ctx, self.value, itemsize, self.name); } /// Returns the size of each element of this `NDArray` as a value. pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.itemsize_field(ctx).load(ctx, self.value, self.name) + self.itemsize_field().load(ctx, self.value, self.name) } - fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).shape + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().shape } /// Stores the array of dimension sizes `dims` into this instance. fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - self.shape_field(ctx).store(ctx, self.value, dims, self.name); + self.shape_field().store(ctx, self.value, dims, self.name); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. @@ -127,16 +146,13 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayShapeProxy(self) } - fn strides_field( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).strides + fn strides_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().strides } /// Stores the array of stride sizes `strides` into this instance. fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) { - self.strides_field(ctx).store(ctx, self.value, strides, self.name); + self.strides_field().store(ctx, self.value, strides, self.name); } /// Convenience method for creating a new array storing the stride with the given `size`. @@ -155,14 +171,14 @@ impl<'ctx> NDArrayValue<'ctx> { NDArrayStridesProxy(self) } - fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).data + fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().data } /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name) + self.data_field().ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. @@ -171,7 +187,7 @@ impl<'ctx> NDArrayValue<'ctx> { .builder .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - self.data_field(ctx).store(ctx, self.value, data.into_pointer_value(), self.name); + self.data_field().store(ctx, self.value, data.into_pointer_value(), self.name); } /// Convenience method for creating a new array storing data elements with the given element @@ -467,7 +483,7 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - NDArrayType::from_type( + NDArrayType::from_pointer_type( self.as_base_value().get_type(), self.dtype, self.ndims, @@ -484,6 +500,8 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDArrayValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDArrayValue<'ctx>) -> Self { value.as_base_value() @@ -508,7 +526,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.shape_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.shape_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -606,7 +624,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.strides_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.strides_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -704,7 +722,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx: &CodeGenContext<'ctx, '_>, _: &G, ) -> PointerValue<'ctx> { - self.0.data_field(ctx).load(ctx, self.0.value, self.0.name) + self.0.data_field().load(ctx, self.0.value, self.0.name) } fn size( @@ -958,7 +976,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty) - .map_value(object.into_pointer_value(), None); + .map_pointer_value(object.into_pointer_value(), None); ScalarOrNDArray::NDArray(ndarray) } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index e485574..c1bf751 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -1,15 +1,18 @@ use inkwell::{ types::{BasicType, IntType}, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, }; -use super::{NDArrayValue, ProxyValue}; +use super::NDArrayValue; use crate::codegen::{ irrt, stmt::{gen_for_callback, BreakContinueHooks}, - types::{ndarray::NDIterType, structure::StructField}, - values::{ArraySliceValue, TypedArrayLikeAdapter}, + types::{ + ndarray::NDIterType, + structure::{StructField, StructProxyType}, + }, + values::{structure::StructProxyValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter}, CodeGenContext, CodeGenerator, }; @@ -23,6 +26,28 @@ pub struct NDIterValue<'ctx> { } impl<'ctx> NDIterValue<'ctx> { + /// Creates an [`NDArrayValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + parent: NDArrayValue<'ctx>, + indices: ArraySliceValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, parent, indices, llvm_usize, name) + } + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -56,11 +81,8 @@ impl<'ctx> NDIterValue<'ctx> { irrt::ndarray::call_nac3_nditer_next(ctx, *self); } - fn element_field( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> StructField<'ctx, PointerValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).element + fn element_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().element } /// Get pointer to the current element. @@ -68,7 +90,7 @@ impl<'ctx> NDIterValue<'ctx> { pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let elem_ty = self.parent.dtype; - let p = self.element_field(ctx).load(ctx, self.as_abi_value(ctx), self.name); + let p = self.element_field().load(ctx, self.as_abi_value(ctx), self.name); ctx.builder .build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element") .unwrap() @@ -81,14 +103,14 @@ impl<'ctx> NDIterValue<'ctx> { ctx.builder.build_load(p, "value").unwrap() } - fn nth_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> { - self.get_type().get_fields(ctx.ctx).nth + fn nth_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().nth } /// Get the index of the current element if this ndarray were a flat ndarray. #[must_use] pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - self.nth_field(ctx).load(ctx, self.as_abi_value(ctx), self.name) + self.nth_field().load(ctx, self.as_abi_value(ctx), self.name) } /// Get the indices of the current element. @@ -110,7 +132,7 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { type Type = NDIterType<'ctx>; fn get_type(&self) -> Self::Type { - NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize) + NDIterType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -122,6 +144,8 @@ impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for NDIterValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: NDIterValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index 3ac2795..b3331b6 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -42,7 +42,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty) - .map_value(input_seq.into_pointer_value(), None); + .map_pointer_value(input_seq.into_pointer_value(), None); let len = input_seq.load_size(ctx, None); // TODO: Find a way to remove this mid-BB allocation @@ -86,7 +86,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty) - .map_value(input_seq.into_struct_value(), None); + .map_struct_value(input_seq.into_struct_value(), None); let len = input_seq.get_type().num_elements(); diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 20bdba7..67e623a 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -1,10 +1,10 @@ use inkwell::{ types::IntType, - values::{BasicValueEnum, IntValue, PointerValue}, + values::{ArrayValue, BasicValueEnum, IntValue, PointerValue}, }; use super::ProxyValue; -use crate::codegen::{types::RangeType, CodeGenContext}; +use crate::codegen::{types::RangeType, CodeGenContext, CodeGenerator}; /// Proxy type for accessing a `range` value in LLVM. #[derive(Copy, Clone)] @@ -15,6 +15,26 @@ pub struct RangeValue<'ctx> { } impl<'ctx> RangeValue<'ctx> { + /// Creates an [`RangeValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_array_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: ArrayValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + /// Creates an [`RangeValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -142,7 +162,7 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type(), self.llvm_usize) + RangeType::from_pointer_type(self.value.get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 08b2b8b..320e219 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -1,6 +1,6 @@ use inkwell::{ types::IntType, - values::{BasicValue, BasicValueEnum, StructValue}, + values::{BasicValue, BasicValueEnum, PointerValue, StructValue}, }; use super::ProxyValue; @@ -26,6 +26,24 @@ impl<'ctx> TupleValue<'ctx> { Self { value, llvm_usize, name } } + /// Creates an [`TupleValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + Self::from_struct_value( + ctx.builder + .build_load(ptr, name.unwrap_or_default()) + .map(BasicValueEnum::into_struct_value) + .unwrap(), + llvm_usize, + name, + ) + } + /// Stores a value into the tuple element at the given `index`. pub fn store_element( &mut self, @@ -62,7 +80,7 @@ impl<'ctx> ProxyValue<'ctx> for TupleValue<'ctx> { type Type = TupleType<'ctx>; fn get_type(&self) -> Self::Type { - TupleType::from_type(self.as_base_value().get_type(), self.llvm_usize) + TupleType::from_struct_type(self.as_base_value().get_type(), self.llvm_usize) } fn as_base_value(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/utils/slice.rs b/nac3core/src/codegen/values/utils/slice.rs index f773935..549e556 100644 --- a/nac3core/src/codegen/values/utils/slice.rs +++ b/nac3core/src/codegen/values/utils/slice.rs @@ -1,14 +1,17 @@ use inkwell::{ types::IntType, - values::{IntValue, PointerValue}, + values::{IntValue, PointerValue, StructValue}, }; use nac3parser::ast::Expr; use crate::{ codegen::{ - types::{structure::StructField, utils::SliceType}, - values::ProxyValue, + types::{ + structure::{StructField, StructProxyType}, + utils::SliceType, + }, + values::{structure::StructProxyValue, ProxyValue}, CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, @@ -24,6 +27,27 @@ pub struct SliceValue<'ctx> { } impl<'ctx> SliceValue<'ctx> { + /// Creates an [`SliceValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + int_ty: IntType<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, int_ty, llvm_usize, name) + } + /// Creates an [`SliceValue`] from a [`PointerValue`]. #[must_use] pub fn from_pointer_value( @@ -155,7 +179,7 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { type Type = SliceType<'ctx>; fn get_type(&self) -> Self::Type { - Self::Type::from_type(self.value.get_type(), self.int_ty, self.llvm_usize) + Self::Type::from_pointer_type(self.value.get_type(), self.int_ty, self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -167,6 +191,8 @@ impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> { } } +impl<'ctx> StructProxyValue<'ctx> for SliceValue<'ctx> {} + impl<'ctx> From> for PointerValue<'ctx> { fn from(value: SliceValue<'ctx>) -> Self { value.as_base_value() diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 165f64a..eff614e 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -577,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeType::new(ctx).map_value(zelf, Some("range")); + let zelf = RangeType::new(ctx).map_pointer_value(zelf, Some("range")); let mut start = None; let mut stop = None; @@ -1280,7 +1280,7 @@ impl<'a> BuiltinBuilder<'a> { let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let size = ctx .builder @@ -1312,7 +1312,7 @@ impl<'a> BuiltinBuilder<'a> { args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray.into_pointer_value(), None); + .map_pointer_value(ndarray.into_pointer_value(), None); let result_tuple = match prim { PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), @@ -1353,7 +1353,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) - .map_value(arg_val.into_pointer_value(), None); + .map_pointer_value(arg_val.into_pointer_value(), None); let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument Ok(Some(ndarray.as_abi_value(ctx).into())) @@ -1391,7 +1391,7 @@ impl<'a> BuiltinBuilder<'a> { args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty) - .map_value(ndarray_val.into_pointer_value(), None); + .map_pointer_value(ndarray_val.into_pointer_value(), None); let shape = parse_numpy_int_sequence(generator, ctx, (shape_ty, shape_val)); From d394b24304ed06bd758fdf684f943131dd84e282 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 13:10:13 +0800 Subject: [PATCH 33/49] [meta] flake: Add LLVM bintools to artiq-{instrumented,pgo} --- flake.nix | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index a48ff69..51551c7 100644 --- a/flake.nix +++ b/flake.nix @@ -85,7 +85,7 @@ name = "nac3artiq-instrumented"; src = self; inherit (nac3artiq) cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-instrumented ]; buildInputs = [ pkgs.python3 llvm-nac3-instrumented ]; cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ]; doCheck = false; @@ -148,7 +148,7 @@ name = "nac3artiq-pgo"; src = self; inherit (nac3artiq) cargoLock; - nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ]; + nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-pgo ]; buildInputs = [ pkgs.python3 llvm-nac3-pgo ]; cargoBuildFlags = [ "--package" "nac3artiq" ]; cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ]; From c32c68b0b0823a21a4528262321dc4c1ebc4fde4 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 5 Feb 2025 15:42:23 +0800 Subject: [PATCH 34/49] flake: update dependencies --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 3e4af70..7e36e53 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1736798957, - "narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=", + "lastModified": 1738680400, + "narHash": "sha256-ooLh+XW8jfa+91F1nhf9OF7qhuA/y1ChLx6lXDNeY5U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3", + "rev": "799ba5bffed04ced7067a91798353d360788b30d", "type": "github" }, "original": { From 6bcdc3ce00ab0b46dead9db89537dcb199121aac Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 10:56:22 +0800 Subject: [PATCH 35/49] [core] codegen/extern_fns: Change expansion pattern Makes more sense to attach the parameter delimiter to the end of each parameter. --- nac3core/src/codegen/extern_fns.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 32cf37d..6412fbd 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -37,8 +37,8 @@ macro_rules! generate_extern_fn { ($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => { #[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )] pub fn $fn_name<'ctx>( - ctx: &CodeGenContext<'ctx, '_> - $(,$args: FloatValue<'ctx>)*, + ctx: &CodeGenContext<'ctx, '_>, + $($args: FloatValue<'ctx>,)* name: Option<&str>, ) -> FloatValue<'ctx> { const FN_NAME: &str = $extern_fn; @@ -158,8 +158,8 @@ macro_rules! generate_linalg_extern_fn { ($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => { #[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )] pub fn $fn_name<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_> - $(,$input_matrix: BasicValueEnum<'ctx>)*, + ctx: &mut CodeGenContext<'ctx, '_>, + $($input_matrix: BasicValueEnum<'ctx>,)* name: Option<&str>, ){ const FN_NAME: &str = $extern_fn; From f52ba9f15140335a082c124602c43ea0539e33c7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 4 Feb 2025 17:13:33 +0800 Subject: [PATCH 36/49] [core] codegen/irrt: Refactor IRRT to use more create/infer fns --- nac3artiq/src/codegen.rs | 82 +++--------- nac3artiq/src/timeline.rs | 48 +++---- nac3core/src/codegen/extern_fns.rs | 104 ++++++++------- nac3core/src/codegen/irrt/list.rs | 72 +++++------ nac3core/src/codegen/irrt/math.rs | 135 +++++++++----------- nac3core/src/codegen/irrt/ndarray/basic.rs | 91 ++++--------- nac3core/src/codegen/irrt/ndarray/iter.rs | 23 +--- nac3core/src/codegen/irrt/ndarray/matmul.rs | 28 +--- nac3core/src/codegen/irrt/range.rs | 27 ++-- nac3core/src/codegen/irrt/slice.rs | 30 +++-- nac3core/src/codegen/irrt/string.rs | 50 +++----- 11 files changed, 270 insertions(+), 420 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 572accc..fb6992b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -15,7 +15,7 @@ use pyo3::{ use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; use nac3core::{ codegen::{ - expr::{destructure_range, gen_call}, + expr::{create_fn_and_call, destructure_range, gen_call, infer_and_call_function}, llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, type_aligned_alloca, @@ -914,47 +914,14 @@ fn rpc_codegen_callback_fn<'ctx>( } // call - if is_async { - let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { - ctx.module.add_function( - "rpc_send_async", - ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], - false, - ), - None, - ) - }); - ctx.builder - .build_call( - rpc_send_async, - &[service_id.into(), tag_ptr.into(), args_ptr.into()], - "rpc.send", - ) - .unwrap(); - } else { - let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { - ctx.module.add_function( - "rpc_send", - ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], - false, - ), - None, - ) - }); - ctx.builder - .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") - .unwrap(); - } + infer_and_call_function( + ctx, + if is_async { "rpc_send_async" } else { "rpc_send" }, + None, + &[service_id.into(), tag_ptr.into(), args_ptr.into()], + Some("rpc.send"), + None, + ); // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); @@ -1168,29 +1135,22 @@ fn polymorphic_print<'ctx>( debug_assert!(!fmt.is_empty()); debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8); - let fn_name = if as_rtio { "rtio_log" } else { "core_log" }; - let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| { - let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let fn_t = if as_rtio { - let llvm_void = ctx.ctx.void_type(); - llvm_void.fn_type(&[llvm_pi8.into()], true) - } else { - let llvm_i32 = ctx.ctx.i32_type(); - llvm_i32.fn_type(&[llvm_pi8.into()], true) - }; - ctx.module.add_function(fn_name, fn_t, None) - }); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let fmt = ctx.gen_string(generator, fmt); let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value(); - ctx.builder - .build_call( - print_fn, - &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), - "", - ) - .unwrap(); + create_fn_and_call( + ctx, + if as_rtio { "rtio_log" } else { "core_log" }, + if as_rtio { None } else { Some(llvm_i32.into()) }, + &[llvm_pi8.into()], + &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), + true, + None, + None, + ); }; let llvm_i32 = ctx.ctx.i32_type(); diff --git a/nac3artiq/src/timeline.rs b/nac3artiq/src/timeline.rs index f51c553..5d6fc79 100644 --- a/nac3artiq/src/timeline.rs +++ b/nac3artiq/src/timeline.rs @@ -1,11 +1,6 @@ -use itertools::Either; - use nac3core::{ - codegen::CodeGenContext, - inkwell::{ - values::{BasicValueEnum, CallSiteValue}, - AddressSpace, AtomicOrdering, - }, + codegen::{expr::infer_and_call_function, CodeGenContext}, + inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}, }; /// Functions for manipulating the timeline. @@ -288,36 +283,27 @@ pub struct ExternTimeFns {} impl TimeFns for ExternTimeFns { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { - let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { - ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) - }); - ctx.builder - .build_call(now_mu, &[], "now_mu") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "now_mu", + Some(ctx.ctx.i64_type().into()), + &[], + Some("now_mu"), + None, + ) + .unwrap() } fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { - let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| { - ctx.module.add_function( - "at_mu", - ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), - None, - ) - }); - ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap(); + assert_eq!(t.get_type(), ctx.ctx.i64_type().into()); + + infer_and_call_function(ctx, "at_mu", None, &[t], Some("at_mu"), None); } fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) { - let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { - ctx.module.add_function( - "delay_mu", - ctx.ctx.void_type().fn_type(&[ctx.ctx.i64_type().into()], false), - None, - ) - }); - ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap(); + assert_eq!(dt.get_type(), ctx.ctx.i64_type().into()); + + infer_and_call_function(ctx, "delay_mu", None, &[dt], Some("delay_mu"), None); } } diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 6412fbd..8dcc2f2 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -1,10 +1,9 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + values::{BasicValueEnum, FloatValue, IntValue}, }; -use itertools::Either; -use super::CodeGenContext; +use super::{expr::infer_and_call_function, CodeGenContext}; /// Macro to generate extern function /// Both function return type and function parameter type are `FloatValue` @@ -46,24 +45,23 @@ macro_rules! generate_extern_fn { let llvm_f64 = ctx.ctx.f64_type(); $(debug_assert_eq!($args.get_type(), llvm_f64);)* - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in [$($attributes),*] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - func - }); - - ctx.builder - .build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + FN_NAME, + Some(llvm_f64.into()), + &[$($args.into()),*], + name, + Some(&|func| { + for attr in [$($attributes),*] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }) + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } }; } @@ -112,25 +110,23 @@ pub fn call_ldexp<'ctx>( debug_assert_eq!(arg.get_type(), llvm_f64); debug_assert_eq!(exp.get_type(), llvm_i32); - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - - func - }); - - ctx.builder - .build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + FN_NAME, + Some(llvm_f64.into()), + &[arg.into(), exp.into()], + name, + Some(&|func| { + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }), + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Macro to generate `np_linalg` and `sp_linalg` functions @@ -163,20 +159,22 @@ macro_rules! generate_linalg_extern_fn { name: Option<&str>, ){ const FN_NAME: &str = $extern_fn; - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - func - }); - - ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap(); + infer_and_call_function( + ctx, + FN_NAME, + None, + &[$($input_matrix.into(),)*], + name, + Some(&|func| { + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + }), + ); } }; } diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs index c01e2cb..20daace 100644 --- a/nac3core/src/codegen/irrt/list.rs +++ b/nac3core/src/codegen/irrt/list.rs @@ -1,13 +1,14 @@ use inkwell::{ types::BasicTypeEnum, - values::{BasicValueEnum, CallSiteValue, IntValue}, + values::{BasicValueEnum, IntValue}, AddressSpace, IntPredicate, }; -use itertools::Either; use super::calculate_len_for_slice_range; use crate::codegen::{ + expr::infer_and_call_function, macros::codegen_unreachable, + stmt::gen_if_callback, values::{ArrayLikeValue, ListValue}, CodeGenContext, CodeGenerator, }; @@ -36,25 +37,6 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(src_idx.2.get_type(), llvm_i32); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8); - let slice_assign_fun = { - let ty_vec = vec![ - llvm_i32.into(), // dest start idx - llvm_i32.into(), // dest end idx - llvm_i32.into(), // dest step - elem_ptr_type.into(), // dest arr ptr - llvm_i32.into(), // dest arr len - llvm_i32.into(), // src start idx - llvm_i32.into(), // src end idx - llvm_i32.into(), // src step - elem_ptr_type.into(), // src arr ptr - llvm_i32.into(), // src arr len - llvm_i32.into(), // size - ]; - ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false); - ctx.module.add_function(fun_symbol, fn_t, None) - }) - }; let zero = llvm_i32.const_zero(); let one = llvm_i32.const_int(1, false); @@ -127,7 +109,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( ); let new_len = { - let args = vec![ + let args = [ dest_idx.0.into(), // dest start idx dest_idx.1.into(), // dest end idx dest_idx.2.into(), // dest step @@ -150,25 +132,35 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( } .into(), ]; - ctx.builder - .build_call(slice_assign_fun, args.as_slice(), "slice_assign") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + fun_symbol, + Some(llvm_i32.into()), + &args, + Some("slice_assign"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() }; // update length - let need_update = - ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let update_bb = ctx.ctx.append_basic_block(current, "update"); - let cont_bb = ctx.ctx.append_basic_block(current, "cont"); - ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); - ctx.builder.position_at_end(update_bb); - let new_len = - ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); - dest_arr.store_size(ctx, new_len); - ctx.builder.build_unconditional_branch(cont_bb).unwrap(); - ctx.builder.position_at_end(cont_bb); + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update") + .unwrap()) + }, + |_, ctx| { + let new_len = + ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); + dest_arr.store_size(ctx, new_len); + Ok(()) + }, + |_, _| Ok(()), + ) + .unwrap(); } diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs index 33445b2..e430080 100644 --- a/nac3core/src/codegen/irrt/math.rs +++ b/nac3core/src/codegen/irrt/math.rs @@ -1,10 +1,10 @@ use inkwell::{ - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + values::{BasicValueEnum, FloatValue, IntValue}, IntPredicate, }; -use itertools::Either; use crate::codegen::{ + expr::infer_and_call_function, macros::codegen_unreachable, {CodeGenContext, CodeGenerator}, }; @@ -18,18 +18,16 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( exp: IntValue<'ctx>, signed: bool, ) -> IntValue<'ctx> { - let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { + let base_type = base.get_type(); + + let symbol = match (base_type.get_bit_width(), exp.get_type().get_bit_width(), signed) { (32, 32, true) => "__nac3_int_exp_int32_t", (64, 64, true) => "__nac3_int_exp_int64_t", (32, 32, false) => "__nac3_int_exp_uint32_t", (64, 64, false) => "__nac3_int_exp_uint64_t", _ => codegen_unreachable!(ctx), }; - let base_type = base.get_type(); - let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { - let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); - ctx.module.add_function(symbol, fn_type, None) - }); + // throw exception when exp < 0 let ge_zero = ctx .builder @@ -48,12 +46,17 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); - ctx.builder - .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + + infer_and_call_function( + ctx, + symbol, + Some(base_type.into()), + &[base.into(), exp.into()], + Some("call_int_pow"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() } /// Generates a call to `isinf` in IR. Returns an `i1` representing the result. @@ -67,20 +70,17 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_isinf", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isinf") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) + infer_and_call_function( + ctx, + "__nac3_isinf", + Some(llvm_i32.into()), + &[v.into()], + Some("isinf"), + None, + ) + .map(BasicValueEnum::into_int_value) + .map(|ret| generator.bool_to_i1(ctx, ret)) + .unwrap() } /// Generates a call to `isnan` in IR. Returns an `i1` representing the result. @@ -94,20 +94,17 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_isnan", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isnan") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) + infer_and_call_function( + ctx, + "__nac3_isnan", + Some(llvm_i32.into()), + &[v.into()], + Some("isnan"), + None, + ) + .map(BasicValueEnum::into_int_value) + .map(|ret| generator.bool_to_i1(ctx, ret)) + .unwrap() } /// Generates a call to `gamma` in IR. Returns an `f64` representing the result. @@ -116,17 +113,16 @@ pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gamma", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gamma") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "__nac3_gamma", + Some(llvm_f64.into()), + &[v.into()], + Some("gamma"), + None, + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. @@ -135,17 +131,16 @@ pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) - assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gammaln", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gammaln") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + "__nac3_gammaln", + Some(llvm_f64.into()), + &[v.into()], + Some("gammaln"), + None, + ) + .map(BasicValueEnum::into_float_value) + .unwrap() } /// Generates a call to `j0` in IR. Returns an `f64` representing the result. @@ -154,15 +149,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo assert_eq!(v.get_type(), llvm_f64); - let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_j0", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "j0") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) + infer_and_call_function(ctx, "__nac3_j0", Some(llvm_f64.into()), &[v.into()], Some("j0"), None) + .map(BasicValueEnum::into_float_value) .unwrap() } diff --git a/nac3core/src/codegen/irrt/ndarray/basic.rs b/nac3core/src/codegen/irrt/ndarray/basic.rs index 5f291c8..06f38f7 100644 --- a/nac3core/src/codegen/irrt/ndarray/basic.rs +++ b/nac3core/src/codegen/irrt/ndarray/basic.rs @@ -1,13 +1,11 @@ use inkwell::{ - types::BasicTypeEnum, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, }; use crate::codegen::{ - expr::{create_and_call_function, infer_and_call_function}, + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, - types::ProxyType, values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor}, CodeGenContext, CodeGenerator, }; @@ -21,24 +19,17 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[ - (llvm_usize.into(), shape.size(ctx, generator).into()), - (llvm_pusize.into(), shape.base_ptr(ctx, generator).into()), - ], + &[shape.size(ctx, generator).into(), shape.base_ptr(ctx, generator).into()], None, None, ); @@ -55,29 +46,22 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(ndarray_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(output_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), &[ - (llvm_usize.into(), ndarray_shape.size(ctx, generator).into()), - (llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()), - (llvm_usize.into(), output_shape.size(ctx, generator).into()), - (llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()), + ndarray_shape.size(ctx, generator).into(), + ndarray_shape.base_ptr(ctx, generator).into(), + output_shape.size(ctx, generator).into(), + output_shape.base_ptr(ctx, generator).into(), ], None, None, @@ -93,15 +77,14 @@ pub fn call_nac3_ndarray_size<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("size"), None, ) @@ -118,15 +101,14 @@ pub fn call_nac3_ndarray_nbytes<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("nbytes"), None, ) @@ -143,15 +125,14 @@ pub fn call_nac3_ndarray_len<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_usize.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("len"), None, ) @@ -167,15 +148,14 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( ndarray: NDArrayValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_ndarray = ndarray.get_type(); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_i1.into()), - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], + &[ndarray.as_abi_value(ctx).into()], Some("is_c_contiguous"), None, ) @@ -194,20 +174,16 @@ pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_ndarray = ndarray.get_type(); assert_eq!(index.get_type(), llvm_usize); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_pi8.into()), - &[ - (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_usize.into(), index.into()), - ], + &[ndarray.as_abi_value(ctx).into(), index.into()], Some("pelement"), None, ) @@ -229,24 +205,16 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let llvm_ndarray = ndarray.get_type(); - assert_eq!( - BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(indices.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); - create_and_call_function( + infer_and_call_function( ctx, &name, Some(llvm_pi8.into()), - &[ - (llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), - ], + &[ndarray.as_abi_value(ctx).into(), indices.base_ptr(ctx, generator).into()], Some("pelement"), None, ) @@ -261,18 +229,9 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( ctx: &CodeGenContext<'ctx, '_>, ndarray: NDArrayValue<'ctx>, ) { - let llvm_ndarray = ndarray.get_type(); - let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); - create_and_call_function( - ctx, - &name, - None, - &[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())], - None, - None, - ); + infer_and_call_function(ctx, &name, None, &[ndarray.as_abi_value(ctx).into()], None, None); } /// Generates a call to `__nac3_ndarray_copy_data`. diff --git a/nac3core/src/codegen/irrt/ndarray/iter.rs b/nac3core/src/codegen/irrt/ndarray/iter.rs index e4424df..d44870d 100644 --- a/nac3core/src/codegen/irrt/ndarray/iter.rs +++ b/nac3core/src/codegen/irrt/ndarray/iter.rs @@ -1,13 +1,8 @@ -use inkwell::{ - types::BasicTypeEnum, - values::{BasicValueEnum, IntValue}, - AddressSpace, -}; +use inkwell::values::{BasicValueEnum, IntValue}; use crate::codegen::{ - expr::{create_and_call_function, infer_and_call_function}, + expr::infer_and_call_function, irrt::get_usize_dependent_function_name, - types::ProxyType, values::{ ndarray::{NDArrayValue, NDIterValue}, ProxyValue, TypedArrayLikeAccessor, @@ -26,23 +21,19 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ) { let llvm_usize = ctx.get_size_type(); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - assert_eq!( - BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(indices.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); - create_and_call_function( + infer_and_call_function( ctx, &name, None, &[ - (iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()), - (ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()), - (llvm_pusize.into(), indices.base_ptr(ctx, generator).into()), + iter.as_abi_value(ctx).into(), + ndarray.as_abi_value(ctx).into(), + indices.base_ptr(ctx, generator).into(), ], None, None, diff --git a/nac3core/src/codegen/irrt/ndarray/matmul.rs b/nac3core/src/codegen/irrt/ndarray/matmul.rs index 0df774f..d2e73ae 100644 --- a/nac3core/src/codegen/irrt/ndarray/matmul.rs +++ b/nac3core/src/codegen/irrt/ndarray/matmul.rs @@ -1,4 +1,4 @@ -use inkwell::{types::BasicTypeEnum, values::IntValue}; +use inkwell::values::IntValue; use crate::codegen::{ expr::infer_and_call_function, irrt::get_usize_dependent_function_name, @@ -22,26 +22,12 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized ) { let llvm_usize = ctx.get_size_type(); - assert_eq!( - BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); - assert_eq!( - BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(), - llvm_usize.into() - ); + assert_eq!(a_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(b_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(final_ndims.get_type(), llvm_usize); + assert_eq!(new_a_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(new_b_shape.element_type(ctx, generator), llvm_usize.into()); + assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); diff --git a/nac3core/src/codegen/irrt/range.rs b/nac3core/src/codegen/irrt/range.rs index 3b6bc31..d624929 100644 --- a/nac3core/src/codegen/irrt/range.rs +++ b/nac3core/src/codegen/irrt/range.rs @@ -1,10 +1,9 @@ use inkwell::{ - values::{BasicValueEnum, CallSiteValue, IntValue}, + values::{BasicValueEnum, IntValue}, IntPredicate, }; -use itertools::Either; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator}; /// Invokes the `__nac3_range_slice_len` in IRRT. /// @@ -23,16 +22,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( const SYMBOL: &str = "__nac3_range_slice_len"; let llvm_i32 = ctx.ctx.i32_type(); - assert_eq!(start.get_type(), llvm_i32); assert_eq!(end.get_type(), llvm_i32); assert_eq!(step.get_type(), llvm_i32); - let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - // assert step != 0, throw exception if not let not_zero = ctx .builder @@ -47,10 +40,14 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( ctx.current_loc, ); - ctx.builder - .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + SYMBOL, + Some(llvm_i32.into()), + &[start.into(), end.into(), step.into()], + Some("calc_len"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() } diff --git a/nac3core/src/codegen/irrt/slice.rs b/nac3core/src/codegen/irrt/slice.rs index 35e2151..cc1f28d 100644 --- a/nac3core/src/codegen/irrt/slice.rs +++ b/nac3core/src/codegen/irrt/slice.rs @@ -1,10 +1,9 @@ -use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue}; -use itertools::Either; +use inkwell::values::{BasicValueEnum, IntValue}; use nac3parser::ast::Expr; use crate::{ - codegen::{CodeGenContext, CodeGenerator}, + codegen::{expr::infer_and_call_function, CodeGenContext, CodeGenerator}, typecheck::typedef::Type, }; @@ -17,23 +16,26 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( length: IntValue<'ctx>, ) -> Result>, String> { const SYMBOL: &str = "__nac3_slice_index_bound"; - let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); + + let llvm_i32 = ctx.ctx.i32_type(); + assert_eq!(length.get_type(), llvm_i32); let i = if let Some(v) = generator.gen_expr(ctx, i)? { v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? } else { return Ok(None); }; + Ok(Some( - ctx.builder - .build_call(func, &[i.into(), length.into()], "bounded_ind") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(), + infer_and_call_function( + ctx, + SYMBOL, + Some(llvm_i32.into()), + &[i, length.into()], + Some("bounded_ind"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap(), )) } diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index e2fd8c0..e015570 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,8 +1,10 @@ -use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; -use itertools::Either; +use inkwell::{ + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, +}; use super::get_usize_dependent_function_name; -use crate::codegen::CodeGenContext; +use crate::codegen::{expr::infer_and_call_function, CodeGenContext}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx>( @@ -13,33 +15,23 @@ pub fn call_string_eq<'ctx>( str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); + let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let llvm_usize = ctx.get_size_type(); + assert_eq!(str1_ptr.get_type(), llvm_pi8); + assert_eq!(str1_len.get_type(), llvm_usize); + assert_eq!(str2_ptr.get_type(), llvm_pi8); + assert_eq!(str2_len.get_type(), llvm_usize); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); - let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { - ctx.module.add_function( - &func_name, - llvm_i1.fn_type( - &[ - str1_ptr.get_type().into(), - str1_len.get_type().into(), - str2_ptr.get_type().into(), - str2_len.get_type().into(), - ], - false, - ), - None, - ) - }); - - ctx.builder - .build_call( - func, - &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], - "str_eq_call", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + infer_and_call_function( + ctx, + &func_name, + Some(llvm_i1.into()), + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + Some("str_eq_call"), + None, + ) + .map(BasicValueEnum::into_int_value) + .unwrap() } From 529fa67855e8f44373d90d7e7b5c7a08b9745f56 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 14:01:10 +0800 Subject: [PATCH 37/49] [core] codegen: Add bool_to_int_type to replace bool_to_{i1,i8} Unifies the implementation for both functions. --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/generator.rs | 20 +++++++++--- nac3core/src/codegen/mod.rs | 52 +++++++++++-------------------- 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 20d296e..f4e03d0 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2001,7 +2001,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ).into_int_value(); let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(result, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { result } diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 620ede0..42c7c71 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -7,7 +7,7 @@ use inkwell::{ use nac3parser::ast::{Expr, Stmt, StrRef}; -use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext}; +use super::{bool_to_int_type, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext}; use crate::{ symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, @@ -248,22 +248,32 @@ pub trait CodeGenerator { gen_block(self, ctx, stmts) } - /// See [`bool_to_i1`]. + /// Converts the value of a boolean-like value `bool_value` into an `i1`. fn bool_to_i1<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { - bool_to_i1(&ctx.builder, bool_value) + self.bool_to_int_type(ctx, bool_value, ctx.ctx.bool_type()) } - /// See [`bool_to_i8`]. + /// Converts the value of a boolean-like value `bool_value` into an `i8`. fn bool_to_i8<'ctx>( &self, ctx: &CodeGenContext<'ctx, '_>, bool_value: IntValue<'ctx>, ) -> IntValue<'ctx> { - bool_to_i8(&ctx.builder, ctx.ctx, bool_value) + self.bool_to_int_type(ctx, bool_value, ctx.ctx.i8_type()) + } + + /// See [`bool_to_int_type`]. + fn bool_to_int_type<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + bool_value: IntValue<'ctx>, + ty: IntType<'ctx>, + ) -> IntValue<'ctx> { + bool_to_int_type(&ctx.builder, bool_value, ty) } } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index a188d1c..f1b9cfb 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -933,7 +933,7 @@ pub fn gen_func_impl< let param_val = param.into_int_value(); if expected_ty.get_bit_width() == 8 && param_val.get_type().get_bit_width() == 1 { - bool_to_i8(&builder, context, param_val) + bool_to_int_type(&builder, param_val, context.i8_type()) } else { param_val } @@ -1103,43 +1103,29 @@ pub fn gen_func<'ctx, G: CodeGenerator>( }) } -/// Converts the value of a boolean-like value `bool_value` into an `i1`. -fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntValue<'ctx> { - if bool_value.get_type().get_bit_width() == 1 { - bool_value - } else { - builder - .build_int_compare( - IntPredicate::NE, - bool_value, - bool_value.get_type().const_zero(), - "tobool", - ) - .unwrap() - } -} - -/// Converts the value of a boolean-like value `bool_value` into an `i8`. -fn bool_to_i8<'ctx>( +/// Converts the value of a boolean-like value `value` into an arbitrary [`IntType`]. +/// +/// This has the same semantics as `(ty)(value != 0)` in C. +/// +/// The returned value is guaranteed to either be `0` or `1`, except for `ty == i1` where only the +/// least-significant bit would be guaranteed to be `0` or `1`. +fn bool_to_int_type<'ctx>( builder: &Builder<'ctx>, - ctx: &'ctx Context, - bool_value: IntValue<'ctx>, + value: IntValue<'ctx>, + ty: IntType<'ctx>, ) -> IntValue<'ctx> { - let value_bits = bool_value.get_type().get_bit_width(); - match value_bits { - 8 => bool_value, - 1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool").unwrap(), - _ => bool_to_i8( + // i1 -> i1 : %value ; no-op + // i1 -> i : zext i1 %value to i ; guaranteed to be 0 or 1 - see docs + // i -> i: zext i1 (icmp eq i %value, 0) to i ; same as i -> i1 -> i + match (value.get_type().get_bit_width(), ty.get_bit_width()) { + (1, 1) => value, + (1, _) => builder.build_int_z_extend(value, ty, "frombool").unwrap(), + _ => bool_to_int_type( builder, - ctx, builder - .build_int_compare( - IntPredicate::NE, - bool_value, - bool_value.get_type().const_zero(), - "", - ) + .build_int_compare(IntPredicate::NE, value, value.get_type().const_zero(), "tobool") .unwrap(), + ty, ), } } From 0d8cb909dda0f3082e3279390e7d6ac27756a0fa Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 10:38:19 +0800 Subject: [PATCH 38/49] [core] codegen/expr: Fix and use gen_unaryop_expr for boolean not ops While refactoring, I ran into the issue where `!true == true`, which was caused by the same upper 7-bit of booleans being undefined issue that was encountered before. It turns out the implementation in `gen_unaryop_expr` is also inadequate, as `(~v & (i1) 0x1)`` will still leave upper 7 bits undefined (for whatever reason). This commit fixes this issue once and for all by using a combination of `icmp` + `zext` to ensure that the resulting value must be `0 | 1`, and refactor to use that whenever we need to invert boolean values. --- nac3core/src/codegen/expr.rs | 39 ++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f4e03d0..53aa5f1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1704,11 +1704,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(if ty == ctx.primitives.bool { let val = val.into_int_value(); if op == ast::Unaryop::Not { - let not = ctx.builder.build_not(val, "not").unwrap(); - let not_bool = - ctx.builder.build_and(not, not.get_type().const_int(1, false), "").unwrap(); + let not = ctx + .builder + .build_int_compare(IntPredicate::EQ, val, val.get_type().const_zero(), "not") + .unwrap(); - not_bool.into() + generator.bool_to_int_type(ctx, not, val.get_type()).into() } else { let llvm_i32 = ctx.ctx.i32_type(); @@ -2001,7 +2002,18 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ).into_int_value(); let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(result, "").unwrap() + gen_unaryop_expr_with_values( + generator, + ctx, + Unaryop::Not, + (&Some(ctx.primitives.bool), result.into()), + ) + .transpose() + .unwrap() + .and_then(|res| { + res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + })? + .into_int_value() } else { result } @@ -2248,8 +2260,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .unwrap() .and_then(|v| { v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - }) - .map(BasicValueEnum::into_int_value)?; + })? + .into_int_value(); Ok(ctx.builder.build_not( generator.bool_to_i1(ctx, cmp), @@ -2285,7 +2297,18 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( // Invert the final value if __ne__ if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + gen_unaryop_expr_with_values( + generator, + ctx, + Unaryop::Not, + (&Some(ctx.primitives.bool), cmp_phi.into()) + ) + .transpose() + .unwrap() + .and_then(|res| { + res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + })? + .into_int_value() } else { cmp_phi } From c37c7e8975ea098b1b2ddcc1f2cab6138b958533 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 10:39:21 +0800 Subject: [PATCH 39/49] [core] codegen/expr: Simplify `gen_*_expr_with_values` return value These functions always return `BasicValueEnum` because they operate on `BasicValueEnum`s, and they also always return a value. --- nac3core/src/codegen/expr.rs | 88 ++++++------------- nac3core/src/codegen/values/ndarray/matmul.rs | 8 +- 2 files changed, 30 insertions(+), 66 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 53aa5f1..986ed99 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op: Binop, right: (&Option, BasicValueEnum<'ctx>), loc: Location, -) -> Result>, String> { +) -> Result, String> { let (left_ty, left_val) = left; let (right_ty, right_val) = right; @@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true)) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false)) } else if [Operator::LShift, Operator::RShift].contains(&op.base) { let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed)) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into())) + Ok(ctx.gen_float_ops(op.base, left_val, right_val)) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert_eq!(op.base, Operator::Pow); @@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( right_val.into_int_value(), Some("f_pow_i"), ); - Ok(Some(res.into())) + Ok(res.into()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } _ => todo!("Operator not supported"), @@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let result = left .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) .split_unsized(generator, ctx); - Ok(Some(result.to_basic_value_enum().into())) + Ok(result.to_basic_value_enum()) } else { // For other operations, they are all elementwise operations. @@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ty2_dtype), right_value), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, common_dtype)?; + )?; Ok(result) }) .unwrap(); - Ok(Some(result.as_abi_value(ctx).into())) + Ok(result.as_abi_value(ctx).into()) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( (&signature, fun_id), vec![(None, right_val.into())], ) - .map(|f| f.map(Into::into)) + .map(Option::unwrap) + .map(BasicValueEnum::into) } } @@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( (&right.custom, right_val), loc, ) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a unary operator expression using the [`Type`] and @@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, op: ast::Unaryop, operand: (&Option, BasicValueEnum<'ctx>), -) -> Result>, String> { +) -> Result, String> { let (ty, val) = operand; let ty = ctx.unifier.get_representative(ty.unwrap()); - Ok(Some(if ty == ctx.primitives.bool { + Ok(if ty == ctx.primitives.bool { let val = val.into_int_value(); if op == ast::Unaryop::Not { let not = ctx @@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(), ), )? - .unwrap() } } else if [ ctx.primitives.int32, @@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx, NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, |generator, ctx, scalar| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? - .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) - .unwrap() + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar)) }, )?; mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() - })) + }) } /// Generates LLVM IR for a unary operator expression. @@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( }; gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a comparison operator expression using the [`Type`] and @@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( left: (Option, BasicValueEnum<'ctx>), ops: &[ast::Cmpop], comparators: &[(Option, BasicValueEnum<'ctx>)], -) -> Result>, String> { +) -> Result, String> { debug_assert_eq!(comparators.len(), ops.len()); if comparators.len() == 1 { @@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( (Some(left_ty_dtype), left_scalar), &[op], &[(Some(right_ty_dtype), right_scalar)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, )?; Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) }, )?; - return Ok(Some(result_ndarray.as_abi_value(ctx).into())); + return Ok(result_ndarray.as_abi_value(ctx).into()); } } @@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx, Unaryop::Not, (&Some(ctx.primitives.bool), result.into()), - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + )?.into_int_value() } else { result } @@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[Cmpop::Eq], &[(Some(right_elem_ty), right)], )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool) - .unwrap() .into_int_value(); gen_if_callback( @@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( Unaryop::Not, (&Some(ctx.primitives.bool), acc.into()), )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool)? .into_int_value() } else { acc @@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[op], &[(Some(right_ty), right_elem)], ) - .transpose() - .unwrap() - .and_then(|v| { - v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value(); + .map(BasicValueEnum::into_int_value)?; Ok(ctx.builder.build_not( generator.bool_to_i1(ctx, cmp), @@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, Unaryop::Not, - (&Some(ctx.primitives.bool), cmp_phi.into()) - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + (&Some(ctx.primitives.bool), cmp_phi.into()), + )?.into_int_value() } else { cmp_phi } @@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) - })?; + })?.unwrap(); - Ok(Some(match cmp_val { - Some(v) => v.into(), - None => return Ok(None), - })) + Ok(cmp_val.into()) } /// Generates LLVM IR for a comparison operator expression. @@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ops, comparator_vals.as_slice(), ) + .map(|res| Some(res.into())) } /// See [`CodeGenerator::gen_expr`]. diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index f12d36c..cc8d059 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Mult), (&Some(rhs_dtype), b_kj), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; // dst_[...]ij += x let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); @@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Add), (&Some(dst_dtype), x), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); Ok(()) From a078481cd2c7e1b51aa991d4fd5c00b6f48a6818 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 13:27:45 +0800 Subject: [PATCH 40/49] [meta] Minor simplification for PrimStore extraction --- nac3artiq/src/codegen.rs | 12 +++++------- nac3core/src/toplevel/builtins.rs | 4 +--- nac3core/src/toplevel/composer.rs | 3 +-- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index fb6992b..d086420 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -41,7 +41,10 @@ use nac3core::{ numpy::unpack_ndarray_var_tys, DefinitionId, GenCall, }, - typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, + }, }; /// The parallelism mode within a block. @@ -389,12 +392,7 @@ fn gen_rpc_tag( ) -> Result<(), String> { use nac3core::typecheck::typedef::TypeEnum::*; - let int32 = ctx.primitives.int32; - let int64 = ctx.primitives.int64; - let float = ctx.primitives.float; - let bool = ctx.primitives.bool; - let str = ctx.primitives.str; - let none = ctx.primitives.none; + let PrimitiveStore { int32, int64, float, bool, str, none, .. } = ctx.primitives; if ctx.unifier.unioned(ty, int32) { buffer.push(b'i'); diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index eff614e..c9b5d22 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -36,9 +36,7 @@ pub fn get_exn_constructor( unifier: &mut Unifier, primitives: &PrimitiveStore, ) -> (TopLevelDef, TopLevelDef, Type, Type) { - let int32 = primitives.int32; - let int64 = primitives.int64; - let string = primitives.str; + let PrimitiveStore { int32, int64, str: string, .. } = *primitives; let exception_fields = make_exception_fields(int32, int64, string); let exn_cons_args = vec![ FuncArg { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a6a0ce7..50d6dd2 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1521,8 +1521,7 @@ impl TopLevelComposer { .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { // create constructor for these classes - let string = primitives_ty.str; - let int64 = primitives_ty.int64; + let PrimitiveStore { str: string, int64, .. } = *primitives_ty; let signature = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { From 2df22e29f738ab214893c6d2f0e083618d9a232d Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 11:18:08 +0800 Subject: [PATCH 41/49] [core] codegen: Simplify TupleType::construct --- nac3core/src/codegen/types/tuple.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index ea66feb..3facf5e 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -115,15 +115,8 @@ impl<'ctx> TupleType<'ctx> { /// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value. #[must_use] - pub fn construct( - &self, - ctx: &CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> >::Value { - self.map_struct_value( - Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(), - name, - ) + pub fn construct(&self, name: Option<&'ctx str>) -> >::Value { + self.map_struct_value(self.as_abi_type().const_zero(), name) } /// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of @@ -143,7 +136,7 @@ impl<'ctx> TupleType<'ctx> { .enumerate() .all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } })); - let mut value = self.construct(ctx, name); + let mut value = self.construct(name); for (i, val) in values.into_iter().enumerate() { value.store_element(ctx, i as u32, val); } From 69542c38a2bbac50401c111c29c907d04da69641 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 11:07:01 +0800 Subject: [PATCH 42/49] [core] codegen: Rename TupleValue::{store,load} -> {insert,extract} Better matches the underlying operation. --- nac3core/src/codegen/types/tuple.rs | 2 +- nac3core/src/codegen/values/ndarray/shape.rs | 2 +- nac3core/src/codegen/values/tuple.rs | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 3facf5e..90abeb3 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -138,7 +138,7 @@ impl<'ctx> TupleType<'ctx> { let mut value = self.construct(name); for (i, val) in values.into_iter().enumerate() { - value.store_element(ctx, i as u32, val); + value.insert_element(ctx, i as u32, val); } value diff --git a/nac3core/src/codegen/values/ndarray/shape.rs b/nac3core/src/codegen/values/ndarray/shape.rs index b3331b6..69e8b50 100644 --- a/nac3core/src/codegen/values/ndarray/shape.rs +++ b/nac3core/src/codegen/values/ndarray/shape.rs @@ -106,7 +106,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( for i in 0..input_seq.get_type().num_elements() { // Get the i-th element off of the tuple and load it into `result`. - let int = input_seq.load_element(ctx, i).into_int_value(); + let int = input_seq.extract_element(ctx, i).into_int_value(); let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap(); unsafe { diff --git a/nac3core/src/codegen/values/tuple.rs b/nac3core/src/codegen/values/tuple.rs index 320e219..1f124c8 100644 --- a/nac3core/src/codegen/values/tuple.rs +++ b/nac3core/src/codegen/values/tuple.rs @@ -45,7 +45,7 @@ impl<'ctx> TupleValue<'ctx> { } /// Stores a value into the tuple element at the given `index`. - pub fn store_element( + pub fn insert_element( &mut self, ctx: &CodeGenContext<'ctx, '_>, index: u32, @@ -63,7 +63,11 @@ impl<'ctx> TupleValue<'ctx> { } /// Loads a value from the tuple element at the given `index`. - pub fn load_element(&self, ctx: &CodeGenContext<'ctx, '_>, index: u32) -> BasicValueEnum<'ctx> { + pub fn extract_element( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: u32, + ) -> BasicValueEnum<'ctx> { ctx.builder .build_extract_value( self.value, From 67f42185de653fc7a7f315dc6ae1c007bfdbda16 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 10 Feb 2025 10:37:55 +0800 Subject: [PATCH 43/49] [core] codegen/expr: Add concrete ndims value to error message --- nac3core/src/codegen/expr.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 986ed99..cd9b87d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -43,7 +43,7 @@ use super::{ use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ - helper::{arraylike_flatten_element_type, PrimDef}, + helper::{arraylike_flatten_element_type, extract_ndims, PrimDef}, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef, }, @@ -1775,10 +1775,13 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { + let ndims = extract_ndims(&ctx.unifier, ty); + codegen_unreachable!( ctx, - "ufunc {} not supported for ndarray[bool, N]", + "ufunc {} not supported for ndarray[bool, {}]", op.op_info().method_name, + ndims, ) } } else { From 0a761cb2637ef9bd97b92362ecc214884883851b Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 15:23:07 +0800 Subject: [PATCH 44/49] [core] Use more TupleType constructors --- nac3core/src/codegen/expr.rs | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index cd9b87d..bab3b75 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType}, + types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -180,23 +180,10 @@ impl<'ctx> CodeGenContext<'ctx, '_> { SymbolValue::Tuple(ls) => { let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); let fields = vals.iter().map(BasicValueEnum::get_type).collect_vec(); - let ty = self.ctx.struct_type(&fields, false); - let ptr = gen_var(self, ty.into(), Some("tuple")).unwrap(); - let zero = self.ctx.i32_type().const_zero(); - unsafe { - for (i, val) in vals.into_iter().enumerate() { - let p = self - .builder - .build_in_bounds_gep( - ptr, - &[zero, self.ctx.i32_type().const_int(i as u64, false)], - "elemptr", - ) - .unwrap(); - self.builder.build_store(p, val).unwrap(); - } - } - self.builder.build_load(ptr, "tup_val").unwrap() + TupleType::new(self, &fields) + .construct_from_objects(self, vals, Some("tup_val")) + .as_abi_value(self) + .into() } SymbolValue::OptionSome(v) => { let ty = match self.unifier.get_ty_immutable(ty).as_ref() { From 35e9c5b38e2bfc240f990b789bca632ed795e1d6 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 15:43:48 +0800 Subject: [PATCH 45/49] [core] codegen: Add String{Type,Value} --- nac3core/src/codegen/expr.rs | 59 ++------- nac3core/src/codegen/irrt/string.rs | 26 ++-- nac3core/src/codegen/mod.rs | 16 +-- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/string.rs | 177 ++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/string.rs | 87 +++++++++++++ 7 files changed, 290 insertions(+), 79 deletions(-) create mode 100644 nac3core/src/codegen/types/string.rs create mode 100644 nac3core/src/codegen/values/string.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index bab3b75..c398ed9 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType, TupleType}, + types::{ndarray::NDArrayType, ListType, RangeType, StringType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -168,14 +168,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { SymbolValue::Bool(v) => self.ctx.i8_type().const_int(u64::from(*v), true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Str(v) => { - let str_ptr = self - .builder - .build_global_string_ptr(v, "const") - .map(|v| v.as_pointer_value().into()) - .unwrap(); - let size = self.get_size_type().const_int(v.len() as u64, false); - let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); - ty.const_named_struct(&[str_ptr, size.into()]).into() + StringType::new(self).construct_constant(self, v, None).as_abi_value(self).into() } SymbolValue::Tuple(ls) => { let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); @@ -308,15 +301,10 @@ impl<'ctx> CodeGenContext<'ctx, '_> { if let Some(v) = self.const_strings.get(v) { Some(*v) } else { - let str_ptr = self - .builder - .build_global_string_ptr(v, "const") - .map(|v| v.as_pointer_value().into()) - .unwrap(); - let size = self.get_size_type().const_int(v.len() as u64, false); - let ty = self.get_llvm_type(generator, self.primitives.str); - let val = - ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); + let val = StringType::new(self) + .construct_constant(self, v, None) + .as_abi_value(self) + .into(); self.const_strings.insert(v.to_string(), val); Some(val) } @@ -1950,39 +1938,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let lhs = lhs.into_struct_value(); - let rhs = rhs.into_struct_value(); + let llvm_str = StringType::new(ctx); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = ctx.get_size_type(); + let lhs = llvm_str.map_struct_value(lhs.into_struct_value(), None); + let rhs = llvm_str.map_struct_value(rhs.into_struct_value(), None); - let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); - ctx.builder.build_store(plhs, lhs).unwrap(); - let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); - ctx.builder.build_store(prhs, rhs).unwrap(); - - let lhs_ptr = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_usize.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - let lhs_len = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], - None, - ).into_int_value(); - - let rhs_ptr = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_usize.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - let rhs_len = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], - None, - ).into_int_value(); - let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); + let result = call_string_eq(ctx, lhs, rhs); if *op == Cmpop::NotEq { gen_unaryop_expr_with_values( generator, diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index e015570..c7e4eeb 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,26 +1,15 @@ -use inkwell::{ - values::{BasicValueEnum, IntValue, PointerValue}, - AddressSpace, -}; +use inkwell::values::{BasicValueEnum, IntValue}; use super::get_usize_dependent_function_name; -use crate::codegen::{expr::infer_and_call_function, CodeGenContext}; +use crate::codegen::{expr::infer_and_call_function, values::StringValue, CodeGenContext}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx>( ctx: &CodeGenContext<'ctx, '_>, - str1_ptr: PointerValue<'ctx>, - str1_len: IntValue<'ctx>, - str2_ptr: PointerValue<'ctx>, - str2_len: IntValue<'ctx>, + str1: StringValue<'ctx>, + str2: StringValue<'ctx>, ) -> IntValue<'ctx> { let llvm_i1 = ctx.ctx.bool_type(); - let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let llvm_usize = ctx.get_size_type(); - assert_eq!(str1_ptr.get_type(), llvm_pi8); - assert_eq!(str1_len.get_type(), llvm_usize); - assert_eq!(str2_ptr.get_type(), llvm_pi8); - assert_eq!(str2_len.get_type(), llvm_usize); let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); @@ -28,7 +17,12 @@ pub fn call_string_eq<'ctx>( ctx, &func_name, Some(llvm_i1.into()), - &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + &[ + str1.extract_ptr(ctx).into(), + str1.extract_len(ctx).into(), + str2.extract_ptr(ctx).into(), + str2.extract_len(ctx).into(), + ], Some("str_eq_call"), None, ) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f1b9cfb..5b6fa21 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -43,7 +43,7 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType}; +use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, StringType, TupleType}; pub mod builtin_fns; pub mod concrete_type; @@ -786,19 +786,7 @@ pub fn gen_func_impl< (primitives.float, context.f64_type().into()), (primitives.bool, context.i8_type().into()), (primitives.str, { - let name = "str"; - match module.get_struct_type(name) { - None => { - let str_type = context.opaque_struct_type("str"); - let fields = [ - context.i8_type().ptr_type(AddressSpace::default()).into(), - generator.get_size_type(context).into(), - ]; - str_type.set_body(&fields, false); - str_type.into() - } - Some(t) => t.as_basic_type_enum(), - } + StringType::new_with_generator(generator, context).as_abi_type().into() }), (primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()), (primitives.exception, { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index abeab5b..bceb804 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -27,11 +27,13 @@ use super::{ }; pub use list::*; pub use range::*; +pub use string::*; pub use tuple::*; mod list; pub mod ndarray; mod range; +mod string; pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/types/string.rs b/nac3core/src/codegen/types/string.rs new file mode 100644 index 0000000..eae275d --- /dev/null +++ b/nac3core/src/codegen/types/string.rs @@ -0,0 +1,177 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{GlobalValue, IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use super::{ + structure::{check_struct_type_matches_fields, StructField, StructFields}, + ProxyType, +}; +use crate::codegen::{values::StringValue, CodeGenContext, CodeGenerator}; + +/// Proxy type for a `str` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct StringType<'ctx> { + ty: StructType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct StringStructFields<'ctx> { + /// Pointer to the first character of the string. + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub ptr: StructField<'ctx, PointerValue<'ctx>>, + + /// Length of the string. + #[value_type(usize)] + pub len: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> StringType<'ctx> { + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields(llvm_usize: IntType<'ctx>) -> StringStructFields<'ctx> { + StringStructFields::new(llvm_usize.get_context(), llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of a `str`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> StructType<'ctx> { + const NAME: &str = "str"; + + if let Some(t) = ctx.get_struct_type(NAME) { + t + } else { + let str_ty = ctx.opaque_struct_type(NAME); + let field_tys = Self::fields(llvm_usize).into_iter().map(|field| field.1).collect_vec(); + str_ty.set_body(&field_tys, false); + str_ty + } + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_str = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_str, llvm_usize } + } + + /// Creates an instance of [`StringType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`StringType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`StringType`] from a [`StructType`] representing a `str`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ty, llvm_usize).is_ok()); + + Self { ty, llvm_usize } + } + + /// Creates an [`StringType`] from a [`PointerType`] representing a `str`. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize) + } + + /// Returns the fields present in this [`StringType`]. + #[must_use] + pub fn get_fields(&self) -> StringStructFields<'ctx> { + Self::fields(self.llvm_usize) + } + + /// Constructs a global constant string. + #[must_use] + pub fn construct_constant( + &self, + ctx: &CodeGenContext<'ctx, '_>, + v: &str, + name: Option<&'ctx str>, + ) -> StringValue<'ctx> { + let str_ptr = ctx + .builder + .build_global_string_ptr(v, "const") + .map(GlobalValue::as_pointer_value) + .unwrap(); + let size = ctx.get_size_type().const_int(v.len() as u64, false); + self.map_struct_value( + self.as_abi_type().const_named_struct(&[str_ptr.into(), size.into()]), + name, + ) + } + + /// Converts an existing value into a [`StringValue`]. + #[must_use] + pub fn map_struct_value( + &self, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value(value, self.llvm_usize, name) + } + + /// Converts an existing value into a [`StringValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(ctx, value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for StringType<'ctx> { + type ABI = StructType<'ctx>; + type Base = StructType<'ctx>; + type Value = StringValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected structure type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + check_struct_type_matches_fields(Self::fields(llvm_usize), ty, "str", &[]) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> From> for StructType<'ctx> { + fn from(value: StringType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 90f327e..cf125fe 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -4,12 +4,14 @@ use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; pub use range::*; +pub use string::*; pub use tuple::*; mod array; mod list; pub mod ndarray; mod range; +mod string; pub mod structure; mod tuple; pub mod utils; diff --git a/nac3core/src/codegen/values/string.rs b/nac3core/src/codegen/values/string.rs new file mode 100644 index 0000000..a4c8bea --- /dev/null +++ b/nac3core/src/codegen/values/string.rs @@ -0,0 +1,87 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue, StructValue}, +}; + +use crate::codegen::{ + types::{structure::StructField, StringType}, + values::ProxyValue, + CodeGenContext, +}; + +/// Proxy type for accessing a `str` value in LLVM. +#[derive(Copy, Clone)] +pub struct StringValue<'ctx> { + value: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> StringValue<'ctx> { + /// Creates an [`StringValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(val, llvm_usize).is_ok()); + + Self { value: val, llvm_usize, name } + } + + /// Creates an [`StringValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let val = ctx.builder.build_load(ptr, "").map(BasicValueEnum::into_struct_value).unwrap(); + + Self::from_struct_value(val, llvm_usize, name) + } + + fn ptr_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().ptr + } + + /// Returns the pointer to the beginning of the string. + pub fn extract_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.ptr_field().extract_value(ctx, self.value) + } + + fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().len + } + + /// Returns the length of the string. + pub fn extract_len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + self.len_field().extract_value(ctx, self.value) + } +} + +impl<'ctx> ProxyValue<'ctx> for StringValue<'ctx> { + type ABI = StructValue<'ctx>; + type Base = StructValue<'ctx>; + type Type = StringType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_struct_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> From> for StructValue<'ctx> { + fn from(value: StringValue<'ctx>) -> Self { + value.as_base_value() + } +} From 57552fb2f641c34e81537dfbca90aab238a535b2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 3 Feb 2025 17:04:37 +0800 Subject: [PATCH 46/49] [core] codegen: Add Option{Type,Value} --- nac3core/src/codegen/expr.rs | 66 +++------ nac3core/src/codegen/mod.rs | 12 +- nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/types/option.rs | 188 ++++++++++++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + nac3core/src/codegen/values/option.rs | 75 ++++++++++ 6 files changed, 296 insertions(+), 49 deletions(-) create mode 100644 nac3core/src/codegen/types/option.rs create mode 100644 nac3core/src/codegen/values/option.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c398ed9..3f48154 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,7 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, RangeType, StringType, TupleType}, + types::{ndarray::NDArrayType, ListType, OptionType, RangeType, StringType, TupleType}, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -179,34 +179,16 @@ impl<'ctx> CodeGenContext<'ctx, '_> { .into() } SymbolValue::OptionSome(v) => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => codegen_unreachable!(self, "must be option type"), - }; let val = self.gen_symbol_val(generator, v, ty); - let ptr = generator - .gen_var_alloc(self, val.get_type(), Some("default_opt_some")) - .unwrap(); - self.builder.build_store(ptr, val).unwrap(); - ptr.into() - } - SymbolValue::OptionNone => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => codegen_unreachable!(self, "must be option type"), - }; - let actual_ptr_type = - self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); - actual_ptr_type.const_null().into() + OptionType::from_unifier_type(generator, self, ty) + .construct_some_value(generator, self, &val, None) + .as_abi_value(self) + .into() } + SymbolValue::OptionNone => OptionType::from_unifier_type(generator, self, ty) + .construct_empty(generator, self, None) + .as_abi_value(self) + .into(), } } @@ -2333,16 +2315,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( const_val.into() } ExprKind::Name { id, .. } if id == &"none".into() => { - match ( - ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), - ctx.unifier.get_ty(ctx.primitives.option).as_ref(), - ) { - (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) - if *obj_id == *opt_id => + match &*ctx.unifier.get_ty(expr.custom.unwrap()) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => { - ctx.get_llvm_type(generator, *params.iter().next().unwrap().1) - .ptr_type(AddressSpace::default()) - .const_null() + OptionType::from_unifier_type(generator, ctx, expr.custom.unwrap()) + .construct_empty(generator, ctx, None) + .as_abi_value(ctx) .into() } _ => codegen_unreachable!(ctx, "must be option type"), @@ -2827,8 +2806,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; } ValueEnum::Dynamic(BasicValueEnum::PointerValue(ptr)) => { - let not_null = - ctx.builder.build_is_not_null(ptr, "unwrap_not_null").unwrap(); + let option = OptionType::from_pointer_type( + ptr.get_type(), + ctx.get_size_type(), + ) + .map_pointer_value(ptr, None); + let not_null = option.is_some(ctx); ctx.make_assert( generator, not_null, @@ -2837,12 +2820,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( [None, None, None], expr.location, ); - return Ok(Some( - ctx.builder - .build_load(ptr, "unwrap_some_load") - .map(Into::into) - .unwrap(), - )); + return Ok(Some(unsafe { option.load(ctx).into() })); } ValueEnum::Dynamic(_) => { codegen_unreachable!(ctx, "option must be static or ptr") diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 5b6fa21..b9c743c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -43,7 +43,9 @@ use crate::{ }; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; -use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, StringType, TupleType}; +use types::{ + ndarray::NDArrayType, ListType, OptionType, ProxyType, RangeType, StringType, TupleType, +}; pub mod builtin_fns; pub mod concrete_type; @@ -538,7 +540,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( if PrimDef::contains_id(*obj_id) { return match &*unifier.get_ty_immutable(ty) { TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => { - get_llvm_type( + let element_type = get_llvm_type( ctx, module, generator, @@ -546,9 +548,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( top_level, type_cache, *params.iter().next().unwrap().1, - ) - .ptr_type(AddressSpace::default()) - .into() + ); + + OptionType::new_with_generator(generator, ctx, &element_type).as_abi_type().into() } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index bceb804..cbab600 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -26,12 +26,14 @@ use super::{ {CodeGenContext, CodeGenerator}, }; pub use list::*; +pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; mod list; pub mod ndarray; +mod option; mod range; mod string; pub mod structure; diff --git a/nac3core/src/codegen/types/option.rs b/nac3core/src/codegen/types/option.rs new file mode 100644 index 0000000..6347e5a --- /dev/null +++ b/nac3core/src/codegen/types/option.rs @@ -0,0 +1,188 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType}, + values::{BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use super::ProxyType; +use crate::{ + codegen::{values::OptionValue, CodeGenContext, CodeGenerator}, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, +}; + +/// Proxy type for an `Option` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct OptionType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> OptionType<'ctx> { + /// Creates an LLVM type corresponding to the expected structure of an `Option`. + #[must_use] + fn llvm_type(element_type: &impl BasicType<'ctx>) -> PointerType<'ctx> { + element_type.ptr_type(AddressSpace::default()) + } + + fn new_impl(element_type: &impl BasicType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_option = Self::llvm_type(element_type); + + Self { ty: llvm_option, llvm_usize } + } + + /// Creates an instance of [`OptionType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(element_type, ctx.get_size_type()) + } + + /// Creates an instance of [`OptionType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + element_type: &impl BasicType<'ctx>, + ) -> Self { + Self::new_impl(element_type, generator.get_size_type(ctx)) + } + + /// Creates an [`OptionType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + // Check unifier type and extract `element_type` + let elem_type = match &*ctx.unifier.get_ty_immutable(ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty + } + + _ => panic!("Expected `option` type, but got {}", ctx.unifier.stringify(ty)), + }; + + let llvm_usize = ctx.get_size_type(); + let llvm_elem_type = ctx.get_llvm_type(generator, elem_type); + + Self::new_impl(&llvm_elem_type, llvm_usize) + } + + /// Creates an [`OptionType`] from a [`PointerType`]. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Returns the element type of this `Option` type. + #[must_use] + pub fn element_type(&self) -> BasicTypeEnum<'ctx> { + BasicTypeEnum::try_from(self.ty.get_element_type()).unwrap() + } + + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will be `Some(v)` if [`value` contains a value][Option::is_some], + /// otherwise `none` will be returned. + #[must_use] + pub fn construct( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: Option>, + name: Option<&'ctx str>, + ) -> >::Value { + let ptr = if let Some(v) = value { + let pvar = self.raw_alloca_var(generator, ctx, name); + ctx.builder.build_store(pvar, v).unwrap(); + pvar + } else { + self.ty.const_null() + }; + + self.map_pointer_value(ptr, name) + } + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will always be `none`. + #[must_use] + pub fn construct_empty( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + self.construct(generator, ctx, None, name) + } + + /// Allocates an [`OptionValue`] on the stack. + /// + /// The returned value will be set to `Some(value)`. + #[must_use] + pub fn construct_some_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: &impl BasicValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + self.construct(generator, ctx, Some(value.as_basic_value_enum()), name) + } + + /// Converts an existing value into a [`OptionValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for OptionType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = OptionValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> { + BasicTypeEnum::try_from(ty.get_element_type()) + .map_err(|()| format!("Expected `ty` to be a BasicTypeEnum, got {ty}"))?; + + Ok(()) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.element_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: OptionType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index cf125fe..7a43ba4 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -3,6 +3,7 @@ use inkwell::{types::IntType, values::BasicValue}; use super::{types::ProxyType, CodeGenContext}; pub use array::*; pub use list::*; +pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; @@ -10,6 +11,7 @@ pub use tuple::*; mod array; mod list; pub mod ndarray; +mod option; mod range; mod string; pub mod structure; diff --git a/nac3core/src/codegen/values/option.rs b/nac3core/src/codegen/values/option.rs new file mode 100644 index 0000000..7fca60f --- /dev/null +++ b/nac3core/src/codegen/values/option.rs @@ -0,0 +1,75 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, IntValue, PointerValue}, +}; + +use super::ProxyValue; +use crate::codegen::{types::OptionType, CodeGenContext}; + +/// Proxy type for accessing a `Option` value in LLVM. +#[derive(Copy, Clone)] +pub struct OptionValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> OptionValue<'ctx> { + /// Creates an [`OptionValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + /// Returns an `i1` indicating if this `Option` instance does not hold a value. + #[must_use] + pub fn is_none(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + ctx.builder.build_is_null(self.value, "").unwrap() + } + + /// Returns an `i1` indicating if this `Option` instance contains a value. + #[must_use] + pub fn is_some(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + ctx.builder.build_is_not_null(self.value, "").unwrap() + } + + /// Loads the value present in this `Option` instance. + /// + /// # Safety + /// + /// The caller must ensure that this `option` value [contains a value][Self::is_some]. + #[must_use] + pub unsafe fn load(&self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> { + ctx.builder.build_load(self.value, "").unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for OptionValue<'ctx> { + type ABI = PointerValue<'ctx>; + type Base = PointerValue<'ctx>; + type Type = OptionType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: OptionValue<'ctx>) -> Self { + value.as_base_value() + } +} From 064aa0411f678bd40fe8acdeffd3cad71894e1b5 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 4 Feb 2025 14:43:33 +0800 Subject: [PATCH 47/49] [core] codegen: Add Exception{Type,Value} --- nac3core/src/codegen/expr.rs | 63 +++--- nac3core/src/codegen/stmt.rs | 42 +--- nac3core/src/codegen/types/exception.rs | 257 +++++++++++++++++++++++ nac3core/src/codegen/types/mod.rs | 2 + nac3core/src/codegen/values/exception.rs | 188 +++++++++++++++++ nac3core/src/codegen/values/mod.rs | 2 + 6 files changed, 488 insertions(+), 66 deletions(-) create mode 100644 nac3core/src/codegen/types/exception.rs create mode 100644 nac3core/src/codegen/values/exception.rs diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3f48154..b3cb369 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -32,7 +32,9 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, - types::{ndarray::NDArrayType, ListType, OptionType, RangeType, StringType, TupleType}, + types::{ + ndarray::NDArrayType, ExceptionType, ListType, OptionType, RangeType, StringType, TupleType, + }, values::{ ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, @@ -576,42 +578,35 @@ impl<'ctx> CodeGenContext<'ctx, '_> { params: [Option>; 3], loc: Location, ) { + let llvm_i32 = self.ctx.i32_type(); + let llvm_i64 = self.ctx.i64_type(); + let llvm_exn = ExceptionType::get_instance(generator, self); + let zelf = if let Some(exception_val) = self.exception_val { - exception_val + llvm_exn.map_pointer_value(exception_val, Some("exn")) } else { - let ty = self.get_llvm_type(generator, self.primitives.exception).into_pointer_type(); - let zelf_ty: BasicTypeEnum = ty.get_element_type().into_struct_type().into(); - let zelf = generator.gen_var_alloc(self, zelf_ty, Some("exn")).unwrap(); - *self.exception_val.insert(zelf) + let zelf = llvm_exn.alloca_var(generator, self, Some("exn")); + self.exception_val = Some(zelf.as_abi_value(self)); + zelf }; - let int32 = self.ctx.i32_type(); - let zero = int32.const_zero(); - unsafe { - let id_ptr = self.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); - let id = self.resolver.get_string_id(name); - self.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); - let ptr = self - .builder - .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") - .unwrap(); - self.builder.build_store(ptr, msg).unwrap(); - let i64_zero = self.ctx.i64_type().const_zero(); - for (i, attr_ind) in [6, 7, 8].iter().enumerate() { - let ptr = self - .builder - .build_in_bounds_gep( - zelf, - &[zero, int32.const_int(*attr_ind, false)], - "exn.param", - ) - .unwrap(); - let val = params[i].map_or(i64_zero, |v| { - self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap() - }); - self.builder.build_store(ptr, val).unwrap(); - } - } - gen_raise(generator, self, Some(&zelf.into()), loc); + + let id = self.resolver.get_string_id(name); + zelf.store_name(self, llvm_i32.const_int(id as u64, false)); + zelf.store_message(self, msg.into_struct_value()); + zelf.store_params( + self, + params + .iter() + .map(|p| { + p.map_or(llvm_i64.const_zero(), |v| { + self.builder.build_int_s_extend(v, self.ctx.i64_type(), "sext").unwrap() + }) + }) + .collect_array() + .as_ref() + .unwrap(), + ); + gen_raise(generator, self, Some(&zelf), loc); } pub fn make_assert( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 0c1b931..35ffeea 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -17,10 +17,10 @@ use super::{ gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, - types::{ndarray::NDArrayType, RangeType}, + types::{ndarray::NDArrayType, ExceptionType, RangeType}, values::{ ndarray::{RustNDIndex, ScalarOrNDArray}, - ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, + ArrayLikeIndexer, ArraySliceValue, ExceptionValue, ListValue, ProxyValue, }, CodeGenContext, CodeGenerator, }; @@ -1337,43 +1337,19 @@ pub fn exn_constructor<'ctx>( pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - exception: Option<&BasicValueEnum<'ctx>>, + exception: Option<&ExceptionValue<'ctx>>, loc: Location, ) { if let Some(exception) = exception { - unsafe { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let exception = exception.into_pointer_value(); - let file_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr") - .unwrap(); - let filename = ctx.gen_string(generator, loc.file.0); - ctx.builder.build_store(file_ptr, filename).unwrap(); - let row_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr") - .unwrap(); - ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap(); - let col_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr") - .unwrap(); - ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap(); + exception.store_location(generator, ctx, loc); - let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); - let name_ptr = ctx - .builder - .build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") - .unwrap(); - ctx.builder.build_store(name_ptr, fun_name).unwrap(); - } + let current_fun = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); + let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); + exception.store_func(ctx, fun_name); let raise = get_builtins(generator, ctx, "__nac3_raise"); let exception = *exception; - ctx.build_call_or_invoke(raise, &[exception], "raise"); + ctx.build_call_or_invoke(raise, &[exception.as_abi_value(ctx).into()], "raise"); } else { let resume = get_builtins(generator, ctx, "__nac3_resume"); ctx.build_call_or_invoke(resume, &[], "resume"); @@ -1860,6 +1836,8 @@ pub fn gen_stmt( } else { return Ok(()); }; + let exc = ExceptionType::get_instance(generator, ctx) + .map_pointer_value(exc.into_pointer_value(), None); gen_raise(generator, ctx, Some(&exc), stmt.location); } else { gen_raise(generator, ctx, None, stmt.location); diff --git a/nac3core/src/codegen/types/exception.rs b/nac3core/src/codegen/types/exception.rs new file mode 100644 index 0000000..0a8ec05 --- /dev/null +++ b/nac3core/src/codegen/types/exception.rs @@ -0,0 +1,257 @@ +use inkwell::{ + context::{AsContextRef, Context}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{IntValue, PointerValue, StructValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use super::{ + structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType}, + ProxyType, +}; +use crate::{ + codegen::{values::ExceptionValue, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Proxy type for an `Exception` in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ExceptionType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ExceptionStructFields<'ctx> { + /// The ID of the exception name. + #[value_type(i32_type())] + pub name: StructField<'ctx, IntValue<'ctx>>, + + /// The file where the exception originated from. + #[value_type(get_struct_type("str").unwrap())] + pub file: StructField<'ctx, StructValue<'ctx>>, + + /// The line number where the exception originated from. + #[value_type(i32_type())] + pub line: StructField<'ctx, IntValue<'ctx>>, + + /// The column number where the exception originated from. + #[value_type(i32_type())] + pub col: StructField<'ctx, IntValue<'ctx>>, + + /// The function name where the exception originated from. + #[value_type(get_struct_type("str").unwrap())] + pub func: StructField<'ctx, StructValue<'ctx>>, + + /// The exception message. + #[value_type(get_struct_type("str").unwrap())] + pub message: StructField<'ctx, StructValue<'ctx>>, + + #[value_type(i64_type())] + pub param0: StructField<'ctx, IntValue<'ctx>>, + + #[value_type(i64_type())] + pub param1: StructField<'ctx, IntValue<'ctx>>, + + #[value_type(i64_type())] + pub param2: StructField<'ctx, IntValue<'ctx>>, +} + +impl<'ctx> ExceptionType<'ctx> { + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + ctx: impl AsContextRef<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ExceptionStructFields<'ctx> { + ExceptionStructFields::new(ctx, llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of an `Exception`. + #[must_use] + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + assert!(ctx.get_struct_type("str").is_some()); + + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { + let llvm_str = Self::llvm_type(ctx, llvm_usize); + + Self { ty: llvm_str, llvm_usize } + } + + /// Creates an instance of [`ExceptionType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ExceptionType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + + /// Creates an [`ExceptionType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self { + // Check unifier type + assert!( + matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.exception.obj_id(&ctx.unifier).unwrap()) + ); + + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an [`ExceptionType`] from a [`StructType`] representing an `Exception`. + #[must_use] + pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize) + } + + /// Creates an [`ExceptionType`] from a [`PointerType`] representing an `Exception`. + #[must_use] + pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, llvm_usize } + } + + /// Returns an instance of [`ExceptionType`] by obtaining the LLVM representation of the builtin + /// `Exception` type. + #[must_use] + pub fn get_instance( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + Self::from_pointer_type( + ctx.get_llvm_type(generator, ctx.primitives.exception).into_pointer_type(), + ctx.get_size_type(), + ) + } + + /// Allocates an instance of [`ExceptionValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca`]. + #[must_use] + pub fn alloca( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(ctx, name), + self.llvm_usize, + name, + ) + } + + /// Allocates an instance of [`ExceptionValue`] as if by calling `alloca` on the base type. + /// + /// See [`ProxyType::raw_alloca_var`]. + #[must_use] + pub fn alloca_var( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca_var(generator, ctx, name), + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ExceptionValue`]. + #[must_use] + pub fn map_struct_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: StructValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_struct_value( + generator, + ctx, + value, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ExceptionValue`]. + #[must_use] + pub fn map_pointer_value( + &self, + value: PointerValue<'ctx>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value(value, self.llvm_usize, name) + } +} + +impl<'ctx> ProxyType<'ctx> for ExceptionType<'ctx> { + type ABI = PointerType<'ctx>; + type Base = PointerType<'ctx>; + type Value = ExceptionValue<'ctx>; + + fn is_representable( + llvm_ty: impl BasicType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + Self::has_same_repr(ty, llvm_usize) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let ctx = ty.get_context(); + + let llvm_ty = ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_ty}")); + }; + + check_struct_type_matches_fields(Self::fields(ctx, llvm_usize), llvm_ty, "exception", &[]) + } + + fn alloca_type(&self) -> impl BasicType<'ctx> { + self.as_abi_type().get_element_type().into_struct_type() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } + + fn as_abi_type(&self) -> Self::ABI { + self.as_base_type() + } +} + +impl<'ctx> StructProxyType<'ctx> for ExceptionType<'ctx> { + type StructFields = ExceptionStructFields<'ctx>; + + fn get_fields(&self) -> Self::StructFields { + Self::fields(self.ty.get_context(), self.llvm_usize) + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ExceptionType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index cbab600..1dc776b 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -25,12 +25,14 @@ use super::{ values::{ArraySliceValue, ProxyValue}, {CodeGenContext, CodeGenerator}, }; +pub use exception::*; pub use list::*; pub use option::*; pub use range::*; pub use string::*; pub use tuple::*; +mod exception; mod list; pub mod ndarray; mod option; diff --git a/nac3core/src/codegen/values/exception.rs b/nac3core/src/codegen/values/exception.rs new file mode 100644 index 0000000..0b1796b --- /dev/null +++ b/nac3core/src/codegen/values/exception.rs @@ -0,0 +1,188 @@ +use inkwell::{ + types::IntType, + values::{IntValue, PointerValue, StructValue}, +}; +use itertools::Itertools; + +use nac3parser::ast::Location; + +use super::{structure::StructProxyValue, ProxyValue, StringValue}; +use crate::codegen::{ + types::{ + structure::{StructField, StructProxyType}, + ExceptionType, + }, + CodeGenContext, CodeGenerator, +}; + +/// Proxy type for accessing an `Exception` value in LLVM. +#[derive(Copy, Clone)] +pub struct ExceptionValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ExceptionValue<'ctx> { + /// Creates an [`ExceptionValue`] from a [`StructValue`]. + #[must_use] + pub fn from_struct_value( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + val: StructValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + let pval = generator + .gen_var_alloc( + ctx, + val.get_type().into(), + name.map(|name| format!("{name}.addr")).as_deref(), + ) + .unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + Self::from_pointer_value(pval, llvm_usize, name) + } + + /// Creates an [`ExceptionValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + Self { value: ptr, llvm_usize, name } + } + + fn name_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().name + } + + /// Stores the ID of the exception name into this instance. + pub fn store_name(&self, ctx: &CodeGenContext<'ctx, '_>, name: IntValue<'ctx>) { + debug_assert_eq!(name.get_type(), ctx.ctx.i32_type()); + + self.name_field().store(ctx, self.value, name, self.name); + } + + fn file_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().file + } + + /// Stores the file name of the exception source into this instance. + pub fn store_file(&self, ctx: &CodeGenContext<'ctx, '_>, file: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(file, self.llvm_usize).is_ok()); + + self.file_field().store(ctx, self.value, file, self.name); + } + + fn line_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().line + } + + fn col_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().col + } + + /// Stores the [location][Location] of the exception source into this instance. + pub fn store_location( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + location: Location, + ) { + let llvm_i32 = ctx.ctx.i32_type(); + + let filename = ctx.gen_string(generator, location.file.0); + self.store_file(ctx, filename); + + self.line_field().store( + ctx, + self.value, + llvm_i32.const_int(location.row as u64, false), + self.name, + ); + self.col_field().store( + ctx, + self.value, + llvm_i32.const_int(location.column as u64, false), + self.name, + ); + } + + fn func_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().func + } + + /// Stores the function name of the exception source into this instance. + pub fn store_func(&self, ctx: &CodeGenContext<'ctx, '_>, func: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(func, self.llvm_usize).is_ok()); + + self.func_field().store(ctx, self.value, func, self.name); + } + + fn message_field(&self) -> StructField<'ctx, StructValue<'ctx>> { + self.get_type().get_fields().message + } + + /// Stores the exception message into this instance. + pub fn store_message(&self, ctx: &CodeGenContext<'ctx, '_>, message: StructValue<'ctx>) { + debug_assert!(StringValue::is_instance(message, self.llvm_usize).is_ok()); + + self.message_field().store(ctx, self.value, message, self.name); + } + + fn param0_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param0 + } + + fn param1_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param1 + } + + fn param2_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().param2 + } + + /// Stores the parameters of the exception into this instance. + /// + /// If the parameter does not exist, pass `i64 0` in the parameter slot. + pub fn store_params(&self, ctx: &CodeGenContext<'ctx, '_>, params: &[IntValue<'ctx>; 3]) { + debug_assert!(params.iter().all(|p| p.get_type() == ctx.ctx.i64_type())); + + [self.param0_field(), self.param1_field(), self.param2_field()] + .into_iter() + .zip_eq(params) + .for_each(|(field, param)| { + field.store(ctx, self.value, *param, self.name); + }); + } +} + +impl<'ctx> ProxyValue<'ctx> for ExceptionValue<'ctx> { + type ABI = PointerValue<'ctx>; + type Base = PointerValue<'ctx>; + type Type = ExceptionType<'ctx>; + + fn get_type(&self) -> Self::Type { + Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } + + fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI { + self.as_base_value() + } +} + +impl<'ctx> StructProxyValue<'ctx> for ExceptionValue<'ctx> {} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ExceptionValue<'ctx>) -> Self { + value.as_base_value() + } +} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index 7a43ba4..5093333 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -2,6 +2,7 @@ use inkwell::{types::IntType, values::BasicValue}; use super::{types::ProxyType, CodeGenContext}; pub use array::*; +pub use exception::*; pub use list::*; pub use option::*; pub use range::*; @@ -9,6 +10,7 @@ pub use string::*; pub use tuple::*; mod array; +mod exception; mod list; pub mod ndarray; mod option; From 715dc71396675d6fd9951b911a7677e482a78a38 Mon Sep 17 00:00:00 2001 From: occheung Date: Mon, 10 Feb 2025 11:08:24 +0800 Subject: [PATCH 48/49] nac3artiq: acquire special python identifiers --- nac3artiq/demo/min_artiq.py | 11 +- nac3artiq/src/codegen.rs | 240 ++++++++++++++++++++---------------- nac3artiq/src/lib.rs | 31 +++++ 3 files changed, 172 insertions(+), 110 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index fef018b..cba3ad2 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -16,7 +16,7 @@ __all__ = [ "rpc", "ms", "us", "ns", "print_int32", "print_int64", "Core", "TTLOut", - "parallel", "sequential" + "parallel", "legacy_parallel", "sequential" ] @@ -245,7 +245,7 @@ class Core: embedding = EmbeddingMap() if allow_registration: - compiler.analyze(registered_functions, registered_classes, set()) + compiler.analyze(registered_functions, registered_classes, special_ids, set()) allow_registration = False if hasattr(method, "__self__"): @@ -336,4 +336,11 @@ class UnwrapNoneError(Exception): artiq_builtin = True parallel = KernelContextManager() +legacy_parallel = KernelContextManager() sequential = KernelContextManager() + +special_ids = { + "parallel": id(parallel), + "legacy_parallel": id(legacy_parallel), + "sequential": id(sequential), +} diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index d086420..cc625a0 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -12,7 +12,7 @@ use pyo3::{ PyObject, PyResult, Python, }; -use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use super::{symbol_resolver::InnerResolver, timeline::TimeFns, SpecialPythonId}; use nac3core::{ codegen::{ expr::{create_fn_and_call, destructure_range, gen_call, infer_and_call_function}, @@ -86,6 +86,9 @@ pub struct ArtiqCodeGenerator<'a> { /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// statement, which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, + + /// Specially treated python IDs to identify `with parallel` and `with sequential` blocks. + special_ids: SpecialPythonId, } impl<'a> ArtiqCodeGenerator<'a> { @@ -93,6 +96,7 @@ impl<'a> ArtiqCodeGenerator<'a> { name: String, size_t: IntType<'_>, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { assert!(matches!(size_t.get_bit_width(), 32 | 64)); ArtiqCodeGenerator { @@ -103,6 +107,7 @@ impl<'a> ArtiqCodeGenerator<'a> { end: None, timeline, parallel_mode: ParallelMode::None, + special_ids, } } @@ -112,9 +117,10 @@ impl<'a> ArtiqCodeGenerator<'a> { ctx: &Context, target_machine: &TargetMachine, timeline: &'a (dyn TimeFns + Sync), + special_ids: SpecialPythonId, ) -> ArtiqCodeGenerator<'a> { let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None); - Self::new(name, llvm_usize, timeline) + Self::new(name, llvm_usize, timeline, special_ids) } /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the @@ -260,122 +266,140 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> { // - If there is a end variable, it indicates that we are (indirectly) inside a // parallel block, and we should update the max end value. if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { - if id == &"parallel".into() || id == &"legacy_parallel".into() { - let old_start = self.start.take(); - let old_end = self.end.take(); - let old_parallel_mode = self.parallel_mode; + let resolver = ctx.resolver.clone(); + if let Some(static_value) = + if let Some((_ptr, static_value, _counter)) = ctx.var_assignment.get(id) { + static_value.clone() + } else if let Some(ValueEnum::Static(val)) = + resolver.get_symbol_value(*id, ctx, self) + { + Some(val) + } else { + None + } + { + let python_id = static_value.get_unique_identifier(); + if python_id == self.special_ids.parallel + || python_id == self.special_ids.legacy_parallel + { + let old_start = self.start.take(); + let old_end = self.end.take(); + let old_parallel_mode = self.parallel_mode; - let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + let now = if let Some(old_start) = &old_start { + self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( + ctx, + self, + old_start.custom.unwrap(), + )? + } else { + self.timeline.emit_now_mu(ctx) + }; + + // Emulate variable allocation, as we need to use the CodeGenContext + // HashMap to store our variable due to lifetime limitation + // Note: we should be able to store variables directly if generic + // associative type is used by limiting the lifetime of CodeGenerator to + // the LLVM Context. + // The name is guaranteed to be unique as users cannot use this as variable + // name. + self.start = old_start.clone().map_or_else( + || { + let start = format!("with-{}-start", self.name_counter).into(); + let start_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: start, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let start = self + .gen_store_target(ctx, &start_expr, Some("start.addr"))? + .unwrap(); + ctx.builder.build_store(start, now).unwrap(); + Ok(Some(start_expr)) as Result<_, String> + }, + |v| Ok(Some(v)), + )?; + let end = format!("with-{}-end", self.name_counter).into(); + let end_expr = Located { + // location does not matter at this point + location: stmt.location, + node: ExprKind::Name { id: end, ctx: *name_ctx }, + custom: Some(ctx.primitives.int64), + }; + let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); + ctx.builder.build_store(end, now).unwrap(); + self.end = Some(end_expr); + self.name_counter += 1; + self.parallel_mode = if python_id == self.special_ids.parallel { + ParallelMode::Deep + } else if python_id == self.special_ids.legacy_parallel { + ParallelMode::Legacy + } else { + unreachable!() + }; + + self.gen_block(ctx, body.iter())?; + + let current = ctx.builder.get_insert_block().unwrap(); + + // if the current block is terminated, move before the terminator + // we want to set the timeline before reaching the terminator + // TODO: This may be unsound if there are multiple exit paths in the + // block... e.g. + // if ...: + // return + // Perhaps we can fix this by using actual with block? + let reset_position = if let Some(terminator) = current.get_terminator() { + ctx.builder.position_before(&terminator); + true + } else { + false + }; + + // set duration + let end_expr = self.end.take().unwrap(); + let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( ctx, self, - old_start.custom.unwrap(), - )? - } else { - self.timeline.emit_now_mu(ctx) - }; + end_expr.custom.unwrap(), + )?; - // Emulate variable allocation, as we need to use the CodeGenContext - // HashMap to store our variable due to lifetime limitation - // Note: we should be able to store variables directly if generic - // associative type is used by limiting the lifetime of CodeGenerator to - // the LLVM Context. - // The name is guaranteed to be unique as users cannot use this as variable - // name. - self.start = old_start.clone().map_or_else( - || { - let start = format!("with-{}-start", self.name_counter).into(); - let start_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: start, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let start = self - .gen_store_target(ctx, &start_expr, Some("start.addr"))? - .unwrap(); - ctx.builder.build_store(start, now).unwrap(); - Ok(Some(start_expr)) as Result<_, String> - }, - |v| Ok(Some(v)), - )?; - let end = format!("with-{}-end", self.name_counter).into(); - let end_expr = Located { - // location does not matter at this point - location: stmt.location, - node: ExprKind::Name { id: end, ctx: *name_ctx }, - custom: Some(ctx.primitives.int64), - }; - let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); - ctx.builder.build_store(end, now).unwrap(); - self.end = Some(end_expr); - self.name_counter += 1; - self.parallel_mode = match id.to_string().as_str() { - "parallel" => ParallelMode::Deep, - "legacy_parallel" => ParallelMode::Legacy, - _ => unreachable!(), - }; + // inside a sequential block + if old_start.is_none() { + self.timeline.emit_at_mu(ctx, end_val); + } - self.gen_block(ctx, body.iter())?; + // inside a parallel block, should update the outer max now_mu + self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - let current = ctx.builder.get_insert_block().unwrap(); + self.parallel_mode = old_parallel_mode; + self.end = old_end; + self.start = old_start; - // if the current block is terminated, move before the terminator - // we want to set the timeline before reaching the terminator - // TODO: This may be unsound if there are multiple exit paths in the - // block... e.g. - // if ...: - // return - // Perhaps we can fix this by using actual with block? - let reset_position = if let Some(terminator) = current.get_terminator() { - ctx.builder.position_before(&terminator); - true - } else { - false - }; + if reset_position { + ctx.builder.position_at_end(current); + } - // set duration - let end_expr = self.end.take().unwrap(); - let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( - ctx, - self, - end_expr.custom.unwrap(), - )?; + return Ok(()); + } else if python_id == self.special_ids.sequential { + // For deep parallel, temporarily take away start to avoid function calls in + // the block from resetting the timeline. + // This does not affect legacy parallel, as the timeline will be reset after + // this block finishes execution. + let start = self.start.take(); + self.gen_block(ctx, body.iter())?; + self.start = start; - // inside a sequential block - if old_start.is_none() { - self.timeline.emit_at_mu(ctx, end_val); + // Reset the timeline when we are exiting the sequential block + // Legacy parallel does not need this, since it will be reset after codegen + // for this statement is completed + if self.parallel_mode == ParallelMode::Deep { + self.timeline_reset_start(ctx)?; + } + + return Ok(()); } - - // inside a parallel block, should update the outer max now_mu - self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; - - self.parallel_mode = old_parallel_mode; - self.end = old_end; - self.start = old_start; - - if reset_position { - ctx.builder.position_at_end(current); - } - - return Ok(()); - } else if id == &"sequential".into() { - // For deep parallel, temporarily take away start to avoid function calls in - // the block from resetting the timeline. - // This does not affect legacy parallel, as the timeline will be reset after - // this block finishes execution. - let start = self.start.take(); - self.gen_block(ctx, body.iter())?; - self.start = start; - - // Reset the timeline when we are exiting the sequential block - // Legacy parallel does not need this, since it will be reset after codegen - // for this statement is completed - if self.parallel_mode == ParallelMode::Deep { - self.timeline_reset_start(ctx)?; - } - - return Ok(()); } } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ba6c4fa..d4136a0 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -162,6 +162,13 @@ pub struct PrimitivePythonId { module: u64, } +#[derive(Clone, Default)] +pub struct SpecialPythonId { + parallel: u64, + legacy_parallel: u64, + sequential: u64, +} + type TopLevelComponent = (Stmt, String, PyObject); // TopLevelComposer is unsendable as it holds the unification table, which is @@ -179,6 +186,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, + special_ids: SpecialPythonId, /// LLVM-related options for code generation. llvm_options: CodeGenLLVMOptions, } @@ -797,6 +805,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), )) }) .collect(); @@ -813,6 +822,7 @@ impl Nac3 { &context, &self.get_llvm_target_machine(), self.time_fns, + self.special_ids.clone(), ); let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); @@ -1192,6 +1202,7 @@ impl Nac3 { string_store: Arc::new(string_store.into()), exception_ids: Arc::default(), deferred_eval_store: DeferredEvaluationStore::new(), + special_ids: Default::default(), llvm_options: CodeGenLLVMOptions { opt_level: OptimizationLevel::Default, target: isa.get_llvm_target_options(), @@ -1203,6 +1214,7 @@ impl Nac3 { &mut self, functions: &PySet, classes: &PySet, + special_ids: &PyDict, content_modules: &PySet, ) -> PyResult<()> { let (modules, class_ids) = @@ -1236,6 +1248,25 @@ impl Nac3 { for module in modules.into_values() { self.register_module(&module, &class_ids)?; } + + self.special_ids = SpecialPythonId { + parallel: special_ids.get_item("parallel").ok().flatten().unwrap().extract().unwrap(), + legacy_parallel: special_ids + .get_item("legacy_parallel") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + sequential: special_ids + .get_item("sequential") + .ok() + .flatten() + .unwrap() + .extract() + .unwrap(), + }; + Ok(()) } From 82a580c5c6e68e0c8f92ecd44da2a18b1adf7d0a Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Mon, 10 Feb 2025 16:53:35 +0800 Subject: [PATCH 49/49] flake: update ARTIQ source used for PGO --- flake.nix | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index 51551c7..243153b 100644 --- a/flake.nix +++ b/flake.nix @@ -113,8 +113,8 @@ (pkgs.fetchFromGitHub { owner = "m-labs"; repo = "artiq"; - rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6"; - sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak="; + rev = "554b0749ca5985bf4d006c4f29a05e83de0a226d"; + sha256 = "sha256-3eSNHTSlmdzLMcEMIspxqjmjrcQe4aIGqIfRgquUg18="; }) ]; buildInputs = [