Implement string equality operator using IRRT and optimise LLVM implementation #561

Merged
sb10q merged 9 commits from ramtej/nac3:feature/string-equality into master 2024-12-30 13:02:09 +08:00
5 changed files with 94 additions and 88 deletions
Showing only changes of commit dd8bf1a35e - Show all commits

View File

@ -8,3 +8,4 @@
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp"
#include "irrt/string.hpp"

View File

@ -0,0 +1,23 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
if (len1 != len2){
return 0;
}
return (__builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0) ? 1 : 0;
}
} // namespace
extern "C" {
uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
return __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
}
uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
}
}

View File

@ -24,7 +24,7 @@ use super::{
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
@ -2045,111 +2045,43 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} else if left_ty == ctx.primitives.str {
assert!(ctx.unifier.unioned(left_ty, right_ty));
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let lhs = lhs.into_struct_value();
let rhs = rhs.into_struct_value();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(plhs, lhs).unwrap();
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(prhs, rhs).unwrap();
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
let lhs_len = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
let rhs_len = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
let current_bb = ctx.builder.get_insert_block().unwrap();
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
ctx.builder.position_at_end(current_bb);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let lhs_char = {
let plhs_data = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
plhs_data,
&[i],
None
).into_int_value()
};
let rhs_char = {
let prhs_data = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
prhs_data,
&[i],
None
).into_int_value()
};
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
},
|_, ctx| {
let bb = ctx.builder.get_insert_block().unwrap();
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let bb = ctx.builder.get_insert_block().unwrap();
let is_len_eq = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_len,
rhs_len,
"",
).unwrap();
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
// Invert the final value if __ne__
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
if *op == Cmpop::NotEq {
ctx.builder.build_not(cmp_phi, "").unwrap()
ctx.builder.build_not(result, "").unwrap()
} else {
cmp_phi
result
}
} else if [left_ty, right_ty]
.iter()

View File

@ -15,12 +15,14 @@ pub use list::*;
pub use math::*;
pub use range::*;
pub use slice::*;
pub use string::*;
mod list;
mod math;
pub mod ndarray;
mod range;
mod slice;
mod string;
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {

View File

@ -0,0 +1,48 @@
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
use itertools::Either;
use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator};
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
str1_ptr: PointerValue<'ctx>,
str1_len: IntValue<'ctx>,
str2_ptr: PointerValue<'ctx>,
str2_len: IntValue<'ctx>,
) -> IntValue<'ctx> {
let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() {
32 => ("nac3_str_eq", ctx.ctx.i32_type()),
64 => ("nac3_str_eq64", ctx.ctx.i64_type()),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let func = ctx.module.get_function(func_name).unwrap_or_else(|| {
ctx.module.add_function(
func_name,
return_type.fn_type(
&[
str1_ptr.get_type().into(),
str1_len.get_type().into(),
str2_ptr.get_type().into(),
str2_len.get_type().into(),
],
false,
),
None,
)
});
let result = ctx
.builder
.build_call(
func,
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
"str_eq_call",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, result)
}