From e13d7533297fec1f50cbfe3ac100c7b85b9bc3d0 Mon Sep 17 00:00:00 2001 From: ram Date: Wed, 11 Dec 2024 17:14:11 +0000 Subject: [PATCH 1/7] Implement string equality operator using IRRT and optimise LLVM implementation --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/string.hpp | 22 ++++++ nac3core/src/codegen/expr.rs | 103 +++++----------------------- nac3core/src/codegen/irrt/mod.rs | 2 + nac3core/src/codegen/irrt/string.rs | 45 ++++++++++++ nac3standalone/demo/src/str.py | 8 ++- 6 files changed, 94 insertions(+), 87 deletions(-) create mode 100644 nac3core/irrt/irrt/string.hpp create mode 100644 nac3core/src/codegen/irrt/string.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 7966322a..05d07ad3 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -3,3 +3,4 @@ #include "irrt/math.hpp" #include "irrt/ndarray.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 00000000..5fc35715 --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "irrt/int_types.hpp" +namespace { +template +int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + if (str1 == str2) return 1; + if (len1 != len2) return 0; + for (SizeT i = 0; i < len1; ++i) { + if (static_cast(str1[i]) != static_cast(str2[i])) { + return 0; + } + } + return 1; +} +} // namespace + +extern "C" { +int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 2cd56700..9945d917 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -24,7 +24,7 @@ use super::{ irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_int_umin, call_memcpy_generic, + call_memcpy_generic, }, macros::codegen_unreachable, need_sret, numpy, @@ -2072,111 +2072,42 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let lhs = lhs.into_struct_value(); let rhs = rhs.into_struct_value(); + let llvm_i32 = ctx.ctx.i32_type(); + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(prhs, rhs).unwrap(); + let lhs_ptr = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); + + let rhs_ptr = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - - let len = call_int_umin(ctx, lhs_len, rhs_len, None); - - let current_bb = ctx.builder.get_insert_block().unwrap(); - let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); - ctx.builder.position_at_end(current_bb); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let lhs_char = { - let plhs_data = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - plhs_data, - &[i], - None - ).into_int_value() - }; - let rhs_char = { - let prhs_data = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - prhs_data, - &[i], - None - ).into_int_value() - }; - - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) - }, - |_, ctx| { - let bb = ctx.builder.get_insert_block().unwrap(); - cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let bb = ctx.builder.get_insert_block().unwrap(); - let is_len_eq = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_len, - rhs_len, - "", - ).unwrap(); - cmp_phi.add_incoming(&[(&is_len_eq, bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = cmp_phi.as_basic_value().into_int_value(); - - // Invert the final value if __ne__ + let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { - cmp_phi + result } } else if [left_ty, right_ty] .iter() diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index f6c4a1eb..3106928f 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -15,11 +15,13 @@ pub use list::*; pub use math::*; pub use ndarray::*; pub use slice::*; +pub use string::*; mod list; mod math; mod ndarray; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 00000000..0a9d3a13 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,45 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Either; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. +pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { + let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let i64_type = ctx.ctx.i64_type(); + let i32_type = ctx.ctx.i32_type(); + let fn_type = i32_type.fn_type( + &[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()], + false, + ); + ctx.module.add_function("nac3_str_eq", fn_type, None) + }); + let result = ctx + .builder + .build_call( + string_eq_fn, + &[ + str1_ptr.into(), + ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(), + str2_ptr.into(), + ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(), + ], + "string_eq", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + generator.bool_to_i1(ctx, result) +} diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index d28f05e2..11632dc8 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -11,6 +11,9 @@ def str_eq(): output_bool("a" == "a") output_bool("test string" == "test string") output_bool("test string1" == "test string2") + output_bool("test" == "testing") + output_bool("abcd" == "abdc") + output_bool(" " == " ") def str_ne(): @@ -21,10 +24,13 @@ def str_ne(): output_bool("a" != "a") output_bool("test string" != "test string") output_bool("test string1" != "test string2") + output_bool("test" != "testing") + output_bool("abcd" != "abdc") + output_bool(" " != " ") def run() -> int32: str_eq() str_ne() - return 0 + return 0 \ No newline at end of file From 780d33c8a74b6b4c6a1ebdd338148daf2566d457 Mon Sep 17 00:00:00 2001 From: ram Date: Thu, 12 Dec 2024 10:05:11 +0000 Subject: [PATCH 2/7] Edit function call to support 32-bit and 64-bit str --- nac3core/irrt/irrt/string.hpp | 15 ++++++++------- nac3standalone/demo/src/str.py | 8 +------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index 5fc35715..d873b9f3 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -1,17 +1,14 @@ #pragma once #include "irrt/int_types.hpp" + namespace { template int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { - if (str1 == str2) return 1; - if (len1 != len2) return 0; - for (SizeT i = 0; i < len1; ++i) { - if (static_cast(str1[i]) != static_cast(str2[i])) { - return 0; - } + if (len1 != len2){ + return 0; } - return 1; + return (__builtin_strncmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; } } // namespace @@ -19,4 +16,8 @@ extern "C" { int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } + +int32_t nac3_str_eq_i32(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} } \ No newline at end of file diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index 11632dc8..d28f05e2 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -11,9 +11,6 @@ def str_eq(): output_bool("a" == "a") output_bool("test string" == "test string") output_bool("test string1" == "test string2") - output_bool("test" == "testing") - output_bool("abcd" == "abdc") - output_bool(" " == " ") def str_ne(): @@ -24,13 +21,10 @@ def str_ne(): output_bool("a" != "a") output_bool("test string" != "test string") output_bool("test string1" != "test string2") - output_bool("test" != "testing") - output_bool("abcd" != "abdc") - output_bool(" " != " ") def run() -> int32: str_eq() str_ne() - return 0 \ No newline at end of file + return 0 From 543a648af87c2bb62e9638825f6444b4135ae1ee Mon Sep 17 00:00:00 2001 From: ram Date: Thu, 12 Dec 2024 10:10:39 +0000 Subject: [PATCH 3/7] Edit function call to support 32-bit and 64-bit str --- nac3core/irrt/irrt/string.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index d873b9f3..036b56a7 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -4,7 +4,7 @@ namespace { template -int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { +SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { if (len1 != len2){ return 0; } @@ -13,7 +13,7 @@ int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT } // namespace extern "C" { -int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { +int64_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } From 9b0d37b1f0198347ab55204f8b6084ae386cd6c4 Mon Sep 17 00:00:00 2001 From: ram Date: Thu, 12 Dec 2024 10:29:00 +0000 Subject: [PATCH 4/7] Amend to follow formatting of other C++ files --- nac3core/irrt/irrt/string.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index 036b56a7..0f28d5de 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -13,11 +13,11 @@ SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT l } // namespace extern "C" { -int64_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { - return __nac3_str_eq_impl(str1, len1, str2, len2); -} - -int32_t nac3_str_eq_i32(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { +uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) { return __nac3_str_eq_impl(str1, len1, str2, len2); } + +uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} } \ No newline at end of file From 0b6a9bd89bc4725771be193ce7bdf612a7271962 Mon Sep 17 00:00:00 2001 From: ram Date: Fri, 13 Dec 2024 15:43:50 +0000 Subject: [PATCH 5/7] Updated to use memcmp instead of strncmp --- nac3core/irrt/irrt/string.hpp | 2 +- nac3core/src/codegen/irrt/string.rs | 39 +++++++++++++---------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp index 0f28d5de..f695dcdc 100644 --- a/nac3core/irrt/irrt/string.hpp +++ b/nac3core/irrt/irrt/string.hpp @@ -8,7 +8,7 @@ SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT l if (len1 != len2){ return 0; } - return (__builtin_strncmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; + return (__builtin_memcmp(str1, str2, static_cast(len1)) == 0) ? 1 : 0; } } // namespace diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index 0a9d3a13..039c7273 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,7 +1,4 @@ -use inkwell::{ - values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}, - AddressSpace, -}; +use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; use crate::codegen::{CodeGenContext, CodeGenerator}; @@ -15,27 +12,27 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( str2_ptr: PointerValue<'ctx>, str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { - let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { - let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let i64_type = ctx.ctx.i64_type(); - let i32_type = ctx.ctx.i32_type(); - let fn_type = i32_type.fn_type( - &[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()], - false, - ); - ctx.module.add_function("nac3_str_eq", fn_type, None) + let func = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { + ctx.module.add_function( + "nac3_str_eq", + ctx.ctx.i32_type().fn_type( + &[ + str1_ptr.get_type().into(), + str1_len.get_type().into(), + str2_ptr.get_type().into(), + str2_len.get_type().into(), + ], + false, + ), + None, + ) }); let result = ctx .builder .build_call( - string_eq_fn, - &[ - str1_ptr.into(), - ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(), - str2_ptr.into(), - ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(), - ], - "string_eq", + func, + &[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()], + "str_eq_call", ) .map(CallSiteValue::try_as_basic_value) .map(|v| v.map_left(BasicValueEnum::into_int_value)) From e1a2f1239d63941dc79cb135a91ff63fea51dc6e Mon Sep 17 00:00:00 2001 From: ram Date: Mon, 16 Dec 2024 09:48:51 +0000 Subject: [PATCH 6/7] Implement 64 bit function call --- nac3core/src/codegen/irrt/string.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs index 039c7273..fb0f27b9 100644 --- a/nac3core/src/codegen/irrt/string.rs +++ b/nac3core/src/codegen/irrt/string.rs @@ -1,7 +1,7 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}; use itertools::Either; -use crate::codegen::{CodeGenContext, CodeGenerator}; +use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator}; /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( @@ -12,10 +12,16 @@ pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( str2_ptr: PointerValue<'ctx>, str2_len: IntValue<'ctx>, ) -> IntValue<'ctx> { - let func = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { + let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() { + 32 => ("nac3_str_eq", ctx.ctx.i32_type()), + 64 => ("nac3_str_eq64", ctx.ctx.i64_type()), + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + + let func = ctx.module.get_function(func_name).unwrap_or_else(|| { ctx.module.add_function( - "nac3_str_eq", - ctx.ctx.i32_type().fn_type( + func_name, + return_type.fn_type( &[ str1_ptr.get_type().into(), str1_len.get_type().into(), From cc185863631276205794c61d5c308c26df07d2eb Mon Sep 17 00:00:00 2001 From: ram Date: Mon, 16 Dec 2024 14:35:09 +0000 Subject: [PATCH 7/7] Use llvm_usize for first GEP index, llvm_i32 for second GEP index --- nac3core/src/codegen/expr.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 9945d917..523380d8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2076,6 +2076,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let rhs = rhs.into_struct_value(); let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); @@ -2084,23 +2085,23 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let lhs_ptr = ctx.build_in_bounds_gep_and_load( plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], + &[llvm_usize.const_zero(), llvm_i32.const_zero()], None, ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); let rhs_ptr = ctx.build_in_bounds_gep_and_load( prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], + &[llvm_usize.const_zero(), llvm_i32.const_zero()], None, ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);