From e13d7533297fec1f50cbfe3ac100c7b85b9bc3d0 Mon Sep 17 00:00:00 2001 From: ram Date: Wed, 11 Dec 2024 17:14:11 +0000 Subject: [PATCH] Implement string equality operator using IRRT and optimise LLVM implementation --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/string.hpp | 22 ++++++ nac3core/src/codegen/expr.rs | 103 +++++----------------------- nac3core/src/codegen/irrt/mod.rs | 2 + nac3core/src/codegen/irrt/string.rs | 45 ++++++++++++ nac3standalone/demo/src/str.py | 8 ++- 6 files changed, 94 insertions(+), 87 deletions(-) create mode 100644 nac3core/irrt/irrt/string.hpp create mode 100644 nac3core/src/codegen/irrt/string.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 7966322..05d07ad 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -3,3 +3,4 @@ #include "irrt/math.hpp" #include "irrt/ndarray.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 0000000..5fc3571 --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "irrt/int_types.hpp" +namespace { +template +int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + if (str1 == str2) return 1; + if (len1 != len2) return 0; + for (SizeT i = 0; i < len1; ++i) { + if (static_cast(str1[i]) != static_cast(str2[i])) { + return 0; + } + } + return 1; +} +} // namespace + +extern "C" { +int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 2cd5670..9945d91 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -24,7 +24,7 @@ use super::{ irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_int_umin, call_memcpy_generic, + call_memcpy_generic, }, macros::codegen_unreachable, need_sret, numpy, @@ -2072,111 +2072,42 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else if left_ty == ctx.primitives.str { assert!(ctx.unifier.unioned(left_ty, right_ty)); - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let lhs = lhs.into_struct_value(); let rhs = rhs.into_struct_value(); + let llvm_i32 = ctx.ctx.i32_type(); + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap(); let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); ctx.builder.build_store(prhs, rhs).unwrap(); + let lhs_ptr = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let lhs_len = ctx.build_in_bounds_gep_and_load( plhs, &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); + + let rhs_ptr = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); let rhs_len = ctx.build_in_bounds_gep_and_load( prhs, &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], None, ).into_int_value(); - - let len = call_int_umin(ctx, lhs_len, rhs_len, None); - - let current_bb = ctx.builder.get_insert_block().unwrap(); - let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); - ctx.builder.position_at_end(current_bb); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let lhs_char = { - let plhs_data = ctx.build_in_bounds_gep_and_load( - plhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - plhs_data, - &[i], - None - ).into_int_value() - }; - let rhs_char = { - let prhs_data = ctx.build_in_bounds_gep_and_load( - prhs, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - None, - ).into_pointer_value(); - - ctx.build_in_bounds_gep_and_load( - prhs_data, - &[i], - None - ).into_int_value() - }; - - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) - }, - |_, ctx| { - let bb = ctx.builder.get_insert_block().unwrap(); - cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let bb = ctx.builder.get_insert_block().unwrap(); - let is_len_eq = ctx.builder.build_int_compare( - IntPredicate::EQ, - lhs_len, - rhs_len, - "", - ).unwrap(); - cmp_phi.add_incoming(&[(&is_len_eq, bb)]); - ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); - - ctx.builder.position_at_end(post_foreach_cmp); - let cmp_phi = cmp_phi.as_basic_value().into_int_value(); - - // Invert the final value if __ne__ + let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); if *op == Cmpop::NotEq { - ctx.builder.build_not(cmp_phi, "").unwrap() + ctx.builder.build_not(result, "").unwrap() } else { - cmp_phi + result } } else if [left_ty, right_ty] .iter() diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index f6c4a1e..3106928 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -15,11 +15,13 @@ pub use list::*; pub use math::*; pub use ndarray::*; pub use slice::*; +pub use string::*; mod list; mod math; mod ndarray; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 0000000..0a9d3a1 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,45 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Either; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. +pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { + let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let i64_type = ctx.ctx.i64_type(); + let i32_type = ctx.ctx.i32_type(); + let fn_type = i32_type.fn_type( + &[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()], + false, + ); + ctx.module.add_function("nac3_str_eq", fn_type, None) + }); + let result = ctx + .builder + .build_call( + string_eq_fn, + &[ + str1_ptr.into(), + ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(), + str2_ptr.into(), + ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(), + ], + "string_eq", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + generator.bool_to_i1(ctx, result) +} diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index d28f05e..11632dc 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -11,6 +11,9 @@ def str_eq(): output_bool("a" == "a") output_bool("test string" == "test string") output_bool("test string1" == "test string2") + output_bool("test" == "testing") + output_bool("abcd" == "abdc") + output_bool(" " == " ") def str_ne(): @@ -21,10 +24,13 @@ def str_ne(): output_bool("a" != "a") output_bool("test string" != "test string") output_bool("test string1" != "test string2") + output_bool("test" != "testing") + output_bool("abcd" != "abdc") + output_bool(" " != " ") def run() -> int32: str_eq() str_ne() - return 0 + return 0 \ No newline at end of file